From fb3fe56bafac7f37781c4cd628a452584d133bd7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 31 Jan 2025 09:21:48 -0800 Subject: [PATCH 01/46] Bump github.com/golang/glog (#1115) Bumps the go_modules group with 1 update in the /codelab directory: [github.com/golang/glog](https://github.com/golang/glog). Updates `github.com/golang/glog` from 1.0.0 to 1.2.4 - [Release notes](https://github.com/golang/glog/releases) - [Commits](https://github.com/golang/glog/compare/v1.0.0...v1.2.4) --- updated-dependencies: - dependency-name: github.com/golang/glog dependency-type: direct:production dependency-group: go_modules ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- codelab/go.mod | 6 ++++-- codelab/go.sum | 14 ++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/codelab/go.mod b/codelab/go.mod index 9f0b67e51..fb770bf22 100644 --- a/codelab/go.mod +++ b/codelab/go.mod @@ -1,15 +1,17 @@ module github.com/google/cel-go/codelab go 1.21 +toolchain go1.22.5 require ( - github.com/golang/glog v1.0.0 + github.com/golang/glog v1.2.4 github.com/google/cel-go v0.21.0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20240823204242-4ba0660f739c + google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 google.golang.org/protobuf v1.34.2 ) require ( + cel.dev/expr v0.19.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect diff --git a/codelab/go.sum b/codelab/go.sum index ec05b8574..f956c29d3 100644 --- a/codelab/go.sum +++ b/codelab/go.sum @@ -1,11 +1,13 @@ +cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= +cel.dev/expr v0.19.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.0 
h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ= -github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc= +github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= @@ -19,8 +21,8 @@ golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 h1:YcyjlL1PRr2Q17/I0dPk2JmYS5CDXfcdb2Z3YRioEbw= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240823204242-4ba0660f739c h1:Kqjm4WpoWvwhMPcrAczoTyMySQmYa9Wy2iL6Con4zn8= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240823204242-4ba0660f739c/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto/googleapis/rpc 
v0.0.0-20240826202546-f6391c0de4c7 h1:2035KHhUv+EpyB+hWgJnaWKJOdX1E95w2S8Rr4uWKTs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 1bf2472a30a005c3c3f37c6b58d0576ecfd5e478 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 31 Jan 2025 13:08:08 -0800 Subject: [PATCH 02/46] Minor update on cost order (#1119) * slight reordering in cost computation to preserve legacy compatibility * Additional tests to attempt to catch ordering issues with cost estimation --- checker/cost.go | 6 ++-- checker/cost_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/checker/cost.go b/checker/cost.go index b9cd8a2ed..59be751c9 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -930,6 +930,9 @@ func (c *coster) computeSize(e ast.Expr) *SizeEstimate { if size, ok := c.computedSizes[e.ID()]; ok { return &size } + if size := computeExprSize(e); size != nil { + return size + } // Ensure size estimates are computed first as users may choose to override the costs that // CEL would otherwise ascribe to the type. 
node := astNode{expr: e, path: c.getPath(e), t: c.getType(e)} @@ -938,9 +941,6 @@ func (c *coster) computeSize(e ast.Expr) *SizeEstimate { c.computedSizes[e.ID()] = *size return size } - if size := computeExprSize(e); size != nil { - return size - } if size := computeTypeSize(c.getType(e)); size != nil { return size } diff --git a/checker/cost_test.go b/checker/cost_test.go index 2bec0e94a..f667ebe0e 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -715,6 +715,31 @@ func TestCost(t *testing.T) { expr: `self.val1 == 1.0`, wanted: FixedCostEstimate(3), }, + { + name: "bytes list max", + expr: "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')].max()", + options: []CostOption{ + OverloadCostEstimate("list_bytes_max", + func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate { + if target != nil { + // Charge 1 cost for comparing each element in the list + elCost := CostEstimate{Min: 1, Max: 1} + // If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way + // of estimating the additional comparison cost. 
+ if elNode := listElementNode(*target); elNode != nil { + k := elNode.Type().Kind() + if k == types.StringKind || k == types.BytesKind { + sz := sizeEstimate(estimator, elNode) + elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor)) + } + return &CallEstimate{CostEstimate: sizeEstimate(estimator, *target).MultiplyByCost(elCost)} + } + } + return nil + }), + }, + wanted: CostEstimate{Min: 25, Max: 35}, + }, } for _, tst := range cases { @@ -745,6 +770,14 @@ func TestCost(t *testing.T) { if err != nil { t.Fatalf("environment creation error: %v", err) } + maxFunc, _ := decls.NewFunction("max", + decls.MemberOverload("list_bytes_max", + []*types.Type{types.NewListType(types.BytesType)}, + types.BytesType)) + err = e.AddFunctions(maxFunc) + if err != nil { + t.Fatalf("environment creation error: %v", err) + } err = e.AddIdents(tc.vars...) if err != nil { t.Fatalf("environment creation error: %s\n", err) @@ -773,6 +806,9 @@ func (tc testCostEstimator) EstimateSize(element AstNode) *SizeEstimate { if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { return &SizeEstimate{Min: 0, Max: l} } + if element.Type() == types.BytesType { + return &SizeEstimate{Min: 0, Max: 12} + } return nil } @@ -793,3 +829,32 @@ func estimateSize(estimator CostEstimator, node AstNode) SizeEstimate { } return SizeEstimate{Min: 0, Max: math.MaxUint64} } + +func listElementNode(list AstNode) AstNode { + if params := list.Type().Parameters(); len(params) > 0 { + lt := params[0] + nodePath := list.Path() + if nodePath != nil { + // Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists + // for this node. + path := make([]string, len(nodePath)+1) + copy(path, nodePath) + path[len(nodePath)] = "@items" + return &astNode{path: path, t: lt, expr: nil} + } else { + // Provide just the type if no path is available so that worst case size can be looked up based on type. 
+ return &astNode{t: lt, expr: nil} + } + } + return nil +} + +func sizeEstimate(estimator CostEstimator, t AstNode) SizeEstimate { + if sz := t.ComputedSize(); sz != nil { + return *sz + } + if sz := estimator.EstimateSize(t); sz != nil { + return *sz + } + return SizeEstimate{Min: 0, Max: math.MaxUint64} +} From 2a85bb6d62da8be539c01cafab98e548d8a56e08 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 3 Feb 2025 12:58:14 -0800 Subject: [PATCH 03/46] Helper methods for subsetting function declaration overloads (#1120) --- cel/decls.go | 32 ++++++++ cel/decls_test.go | 182 ++++++++++++++++++++++++++++++++++++++++++ common/decls/decls.go | 54 +++++++++++++ 3 files changed, 268 insertions(+) diff --git a/cel/decls.go b/cel/decls.go index 418806021..7a2bd9b7c 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -194,6 +194,38 @@ func Function(name string, opts ...FunctionOpt) EnvOption { } } +// OverloadSelector selects an overload associated with a given function when it returns true. +// +// Used in combination with the FunctionDecl.Subset method. +type OverloadSelector = decls.OverloadSelector + +// IncludeOverloads defines an OverloadSelector which allow-lists a set of overloads by their ids. +func IncludeOverloads(overloadIDs ...string) OverloadSelector { + return decls.IncludeOverloads(overloadIDs...) +} + +// ExcludeOverloads defines an OverloadSelector which deny-lists a set of overloads by their ids. +func ExcludeOverloads(overloadIDs ...string) OverloadSelector { + return decls.ExcludeOverloads(overloadIDs...) +} + +// FunctionDecls provides one or more fully formed function declaration to be added to the environment. 
+func FunctionDecls(funcs ...*decls.FunctionDecl) EnvOption { + return func(e *Env) (*Env, error) { + var err error + for _, fn := range funcs { + if existing, found := e.functions[fn.Name()]; found { + fn, err = existing.Merge(fn) + if err != nil { + return nil, err + } + } + e.functions[fn.Name()] = fn + } + return e, nil + } +} + // FunctionOpt defines a functional option for configuring a function declaration. type FunctionOpt = decls.FunctionOpt diff --git a/cel/decls_test.go b/cel/decls_test.go index f15862fac..9024d74db 100644 --- a/cel/decls_test.go +++ b/cel/decls_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/stdlib" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" @@ -780,6 +781,187 @@ func TestExprDeclToDeclarationInvalid(t *testing.T) { } } +func TestFunctionDeclExcludeOverloads(t *testing.T) { + funcs := []*decls.FunctionDecl{} + for _, fn := range stdlib.Functions() { + if fn.Name() == operators.Add { + fn = fn.Subset(ExcludeOverloads(overloads.AddList, overloads.AddBytes, overloads.AddString)) + } + funcs = append(funcs, fn) + } + env, err := NewCustomEnv(FunctionDecls(funcs...)) + if err != nil { + t.Fatalf("NewCustomEnv() failed: %v", err) + } + + successTests := []struct { + name string + expr string + want ref.Val + }{ + { + name: "ints", + expr: "1 + 1", + want: types.Int(2), + }, + { + name: "doubles", + expr: "1.5 + 1.5", + want: types.Double(3.0), + }, + { + name: "uints", + expr: "1u + 2u", + want: types.Uint(3), + }, + { + name: "timestamp plus duration", + expr: "timestamp('2001-01-01T00:00:00Z') + duration('1h') == timestamp('2001-01-01T01:00:00Z')", + want: types.True, + }, + { + name: "durations", + expr: "duration('1h') + duration('1m') == duration('1h1m')", + want: types.True, + }, + } + for _, tst := 
range successTests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile() failed: %v", iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + t.Fatalf("prg.Eval() errored: %v", err) + } + if out.Equal(tc.want) != types.True { + t.Errorf("Eval() got %v, wanted %v", out, tc.want) + } + }) + } + failureTests := []struct { + name string + expr string + }{ + { + name: "strings", + expr: "'a' + 'b'", + }, + { + name: "bytes", + expr: "b'123' + b'456'", + }, + { + name: "lists", + expr: "[1] + [2, 3]", + }, + } + for _, tst := range failureTests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + _, iss := env.Compile(tc.expr) + if iss.Err() == nil { + t.Error("env.Compile() got ast, wanted error") + } + }) + } +} + +func TestFunctionDeclIncludeOverloads(t *testing.T) { + funcs := []*decls.FunctionDecl{} + for _, fn := range stdlib.Functions() { + if fn.Name() == operators.Add { + fn = fn.Subset(IncludeOverloads(overloads.AddInt64, overloads.AddDouble)) + } + funcs = append(funcs, fn) + } + env, err := NewCustomEnv(FunctionDecls(funcs...)) + if err != nil { + t.Fatalf("NewCustomEnv() failed: %v", err) + } + + successTests := []struct { + name string + expr string + want ref.Val + }{ + { + name: "ints", + expr: "1 + 1", + want: types.Int(2), + }, + { + name: "doubles", + expr: "1.5 + 1.5", + want: types.Double(3.0), + }, + } + for _, tst := range successTests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile() failed: %v", iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + t.Fatalf("prg.Eval() errored: %v", err) + } + if out.Equal(tc.want) != types.True { + t.Errorf("Eval() got %v, 
wanted %v", out, tc.want) + } + }) + } + failureTests := []struct { + name string + expr string + }{ + { + name: "strings", + expr: "'a' + 'b'", + }, + { + name: "bytes", + expr: "b'123' + b'456'", + }, + { + name: "lists", + expr: "[1] + [2, 3]", + }, + { + name: "uints", + expr: "1u + 2u", + }, + { + name: "timestamp plus duration", + expr: "timestamp('2001-01-01T00:00:00Z') + duration('1h') == timestamp('2001-01-01T01:00:00Z')", + }, + { + name: "durations", + expr: "duration('1h') + duration('1m') == duration('1h1m')", + }, + } + for _, tst := range failureTests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + _, iss := env.Compile(tc.expr) + if iss.Err() == nil { + t.Error("env.Compile() got ast, wanted error") + } + }) + } +} + func testParse(t testing.TB, env *Env, expr string, want any) { t.Helper() ast, iss := env.Parse(expr) diff --git a/common/decls/decls.go b/common/decls/decls.go index bfeb52c51..451a6c0d6 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -148,6 +148,60 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) { return merged, nil } +// OverloadSelector selects an overload associated with a given function when it returns true. +// +// Used in combination with the Subset method. +type OverloadSelector func(overload *OverloadDecl) bool + +// IncludeOverloads defines an OverloadSelector which allow-lists a set of overloads by their ids. +func IncludeOverloads(overloadIDs ...string) OverloadSelector { + return func(overload *OverloadDecl) bool { + for _, oID := range overloadIDs { + if overload.id == oID { + return true + } + } + return false + } +} + +// ExcludeOverloads defines an OverloadSelector which deny-lists a set of overloads by their ids. 
+func ExcludeOverloads(overloadIDs ...string) OverloadSelector { + return func(overload *OverloadDecl) bool { + for _, oID := range overloadIDs { + if overload.id == oID { + return false + } + } + return true + } +} + +// Subset returns a new function declaration which contains only the overloads with the specified IDs. +func (f *FunctionDecl) Subset(selector OverloadSelector) *FunctionDecl { + if f == nil { + return nil + } + overloads := make(map[string]*OverloadDecl) + overloadOrdinals := make([]string, 0, len(f.overloadOrdinals)) + for _, oID := range f.overloadOrdinals { + overload := f.overloads[oID] + if selector(overload) { + overloads[oID] = overload + overloadOrdinals = append(overloadOrdinals, oID) + } + } + subset := &FunctionDecl{ + name: f.Name(), + overloads: overloads, + singleton: f.singleton, + disableTypeGuards: f.disableTypeGuards, + state: f.state, + overloadOrdinals: overloadOrdinals, + } + return subset +} + // AddOverload ensures that the new overload does not collide with an existing overload signature; // however, if the function signatures are identical, the implementation may be rewritten as its // difficult to compare functions by object identity. From 1ef45b2df9c8ad7ee88215fca390bbeb78b9d4ba Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 4 Feb 2025 13:22:20 -0800 Subject: [PATCH 04/46] Indicate that CEL is an official Google product (#1122) --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 49bebf25e..5318e60dd 100644 --- a/README.md +++ b/README.md @@ -278,8 +278,6 @@ bazel test ... Released under the [Apache License](LICENSE). -Disclaimer: This is not an official Google product. 
- [1]: https://github.com/google/cel-spec [2]: https://groups.google.com/forum/#!forum/cel-go-discuss [3]: https://github.com/google/cel-cpp From e086729a7ed778497dfa34b4029366a0e7ea0a8d Mon Sep 17 00:00:00 2001 From: "zev.ac" Date: Wed, 5 Feb 2025 12:27:42 +0530 Subject: [PATCH 05/46] Add k8s custom policy tag handler for test (#1121) * Add k8s custom policy tag handler for test * Add copyright and remove redundant attribute from go_library target --- policy/BUILD.bazel | 1 + policy/helper_test.go | 81 +------------------------- policy/test_tag_handler_k8s.go | 100 +++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 80 deletions(-) create mode 100644 policy/test_tag_handler_k8s.go diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index 5b63cb7b7..5fec59d78 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -28,6 +28,7 @@ go_library( "config.go", "parser.go", "source.go", + "test_tag_handler_k8s.go", ], importpath = "github.com/google/cel-go/policy", deps = [ diff --git a/policy/helper_test.go b/policy/helper_test.go index 396d919c1..f65b19c68 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -39,7 +39,7 @@ var ( { name: "k8s", parseOpts: []ParserOption{func(p *Parser) (*Parser, error) { - p.TagVisitor = k8sTagHandler() + p.TagVisitor = K8sTestTagHandler() return p, nil }}, expr: ` @@ -268,85 +268,6 @@ ERROR: testdata/errors_unreachable/policy.yaml:36:13: match creates unreachable } ) -func k8sTagHandler() TagVisitor { - return k8sAdmissionTagHandler{TagVisitor: DefaultTagVisitor()} -} - -type k8sAdmissionTagHandler struct { - TagVisitor -} - -func (k8sAdmissionTagHandler) PolicyTag(ctx ParserContext, id int64, tagName string, node *yaml.Node, policy *Policy) { - switch tagName { - case "kind": - policy.SetMetadata("kind", ctx.NewString(node).Value) - case "metadata": - m := k8sMetadata{} - if err := node.Decode(&m); err != nil { - ctx.ReportErrorAtID(id, "invalid yaml metadata node: %v, error: %w", node, err) - return 
- } - case "spec": - spec := ctx.ParseRule(ctx, policy, node) - policy.SetRule(spec) - default: - ctx.ReportErrorAtID(id, "unsupported policy tag: %s", tagName) - } -} - -func (k8sAdmissionTagHandler) RuleTag(ctx ParserContext, id int64, tagName string, node *yaml.Node, policy *Policy, r *Rule) { - switch tagName { - case "failurePolicy": - policy.SetMetadata(tagName, ctx.NewString(node).Value) - case "matchConstraints": - m := k8sMatchConstraints{} - if err := node.Decode(&m); err != nil { - ctx.ReportErrorAtID(id, "invalid yaml matchConstraints node: %v, error: %w", node, err) - return - } - case "validations": - id := ctx.CollectMetadata(node) - if node.LongTag() != "tag:yaml.org,2002:seq" { - ctx.ReportErrorAtID(id, "invalid 'validations' type, expected list got: %s", node.LongTag()) - return - } - for _, val := range node.Content { - r.AddMatch(ctx.ParseMatch(ctx, policy, val)) - } - default: - ctx.ReportErrorAtID(id, "unsupported rule tag: %s", tagName) - } -} - -func (k8sAdmissionTagHandler) MatchTag(ctx ParserContext, id int64, tagName string, node *yaml.Node, policy *Policy, m *Match) { - if m.Output().Value == "" { - m.SetOutput(ValueString{Value: "'invalid admission request'"}) - } - switch tagName { - case "expression": - // The K8s expression to validate must return false in order to generate a violation message. 
- condition := ctx.NewString(node) - condition.Value = "!(" + condition.Value + ")" - m.SetCondition(condition) - case "messageExpression": - m.SetOutput(ctx.NewString(node)) - } -} - -type k8sMetadata struct { - Name string `yaml:"name"` -} - -type k8sMatchConstraints struct { - ResourceRules []k8sResourceRule `yaml:"resourceRules"` -} - -type k8sResourceRule struct { - APIGroups []string `yaml:"apiGroups"` - APIVersions []string `yaml:"apiVersions"` - Operations []string `yaml:"operations"` -} - func readPolicy(t testing.TB, fileName string) *Source { t.Helper() policyBytes, err := os.ReadFile(fileName) diff --git a/policy/test_tag_handler_k8s.go b/policy/test_tag_handler_k8s.go new file mode 100644 index 000000000..54edc71ba --- /dev/null +++ b/policy/test_tag_handler_k8s.go @@ -0,0 +1,100 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package policy + +import ( + "gopkg.in/yaml.v3" +) + +// K8sTestTagHandler returns a TagVisitor which handles custom policy tags used in K8s policies. This is +// a helper function to be used in tests. 
+func K8sTestTagHandler() TagVisitor { + return k8sAdmissionTagHandler{TagVisitor: DefaultTagVisitor()} +} + +type k8sAdmissionTagHandler struct { + TagVisitor +} + +func (k8sAdmissionTagHandler) PolicyTag(ctx ParserContext, id int64, tagName string, node *yaml.Node, policy *Policy) { + switch tagName { + case "kind": + policy.SetMetadata("kind", ctx.NewString(node).Value) + case "metadata": + m := k8sMetadata{} + if err := node.Decode(&m); err != nil { + ctx.ReportErrorAtID(id, "invalid yaml metadata node: %v, error: %w", node, err) + return + } + case "spec": + spec := ctx.ParseRule(ctx, policy, node) + policy.SetRule(spec) + default: + ctx.ReportErrorAtID(id, "unsupported policy tag: %s", tagName) + } +} + +func (k8sAdmissionTagHandler) RuleTag(ctx ParserContext, id int64, tagName string, node *yaml.Node, policy *Policy, r *Rule) { + switch tagName { + case "failurePolicy": + policy.SetMetadata(tagName, ctx.NewString(node).Value) + case "matchConstraints": + m := k8sMatchConstraints{} + if err := node.Decode(&m); err != nil { + ctx.ReportErrorAtID(id, "invalid yaml matchConstraints node: %v, error: %w", node, err) + return + } + case "validations": + id := ctx.CollectMetadata(node) + if node.LongTag() != "tag:yaml.org,2002:seq" { + ctx.ReportErrorAtID(id, "invalid 'validations' type, expected list got: %s", node.LongTag()) + return + } + for _, val := range node.Content { + r.AddMatch(ctx.ParseMatch(ctx, policy, val)) + } + default: + ctx.ReportErrorAtID(id, "unsupported rule tag: %s", tagName) + } +} + +func (k8sAdmissionTagHandler) MatchTag(ctx ParserContext, id int64, tagName string, node *yaml.Node, policy *Policy, m *Match) { + if m.Output().Value == "" { + m.SetOutput(ValueString{Value: "'invalid admission request'"}) + } + switch tagName { + case "expression": + // The K8s expression to validate must return false in order to generate a violation message. 
+ condition := ctx.NewString(node) + condition.Value = "!(" + condition.Value + ")" + m.SetCondition(condition) + case "messageExpression": + m.SetOutput(ctx.NewString(node)) + } +} + +type k8sMetadata struct { + Name string `yaml:"name"` +} + +type k8sMatchConstraints struct { + ResourceRules []k8sResourceRule `yaml:"resourceRules"` +} + +type k8sResourceRule struct { + APIGroups []string `yaml:"apiGroups"` + APIVersions []string `yaml:"apiVersions"` + Operations []string `yaml:"operations"` +} From c0532516cb14d45f8f383ae71146d5718c72669d Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 7 Feb 2025 14:59:47 -0800 Subject: [PATCH 06/46] Introduce cel package aliases for Activation (#1123) --- cel/env.go | 25 ++++++++++---------- cel/options.go | 8 +++---- cel/program.go | 56 ++++++++++++++++++++++++++++++++------------- ext/bindings.go | 8 +++---- repl/BUILD.bazel | 3 +-- repl/evaluator.go | 5 ++-- test/bench/bench.go | 4 ++-- 7 files changed, 65 insertions(+), 44 deletions(-) diff --git a/cel/env.go b/cel/env.go index 3bfe42899..e3de439de 100644 --- a/cel/env.go +++ b/cel/env.go @@ -502,31 +502,30 @@ func (e *Env) TypeProvider() ref.TypeProvider { return &interopLegacyTypeProvider{Provider: e.provider} } -// UnknownVars returns an interpreter.PartialActivation which marks all variables declared in the -// Env as unknown AttributePattern values. +// UnknownVars returns a PartialActivation which marks all variables declared in the Env as +// unknown AttributePattern values. // -// Note, the UnknownVars will behave the same as an interpreter.EmptyActivation unless the -// PartialAttributes option is provided as a ProgramOption. -func (e *Env) UnknownVars() interpreter.PartialActivation { +// Note, the UnknownVars will behave the same as an cel.NoVars() unless the PartialAttributes +// option is provided as a ProgramOption. 
+func (e *Env) UnknownVars() PartialActivation { act := interpreter.EmptyActivation() part, _ := PartialVars(act, e.computeUnknownVars(act)...) return part } -// PartialVars returns an interpreter.PartialActivation where all variables not in the input variable +// PartialVars returns a PartialActivation where all variables not in the input variable // set, but which have been configured in the environment, are marked as unknown. // -// The `vars` value may either be an interpreter.Activation or any valid input to the -// interpreter.NewActivation call. +// The `vars` value may either be an Activation or any valid input to the cel.NewActivation call. // // Note, this is equivalent to calling cel.PartialVars and manually configuring the set of unknown // variables. For more advanced use cases of partial state where portions of an object graph, rather // than top-level variables, are missing the PartialVars() method may be a more suitable choice. // -// Note, the PartialVars will behave the same as an interpreter.EmptyActivation unless the -// PartialAttributes option is provided as a ProgramOption. -func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) { - act, err := interpreter.NewActivation(vars) +// Note, the PartialVars will behave the same as cel.NoVars() unless the PartialAttributes +// option is provided as a ProgramOption. +func (e *Env) PartialVars(vars any) (PartialActivation, error) { + act, err := NewActivation(vars) if err != nil { return nil, err } @@ -708,7 +707,7 @@ func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) { // computeUnknownVars determines a set of missing variables based on the input activation and the // environment's configured declaration set. 
-func (e *Env) computeUnknownVars(vars interpreter.Activation) []*interpreter.AttributePattern { +func (e *Env) computeUnknownVars(vars Activation) []*interpreter.AttributePattern { var unknownPatterns []*interpreter.AttributePattern for _, v := range e.variables { varName := v.Name() diff --git a/cel/options.go b/cel/options.go index 85f777e95..82b6c8d9f 100644 --- a/cel/options.go +++ b/cel/options.go @@ -401,10 +401,10 @@ func Functions(funcs ...*functions.Overload) ProgramOption { // variables with the same name provided to the Eval() call. If Globals is used in a Library with // a Lib EnvOption, vars may shadow variables provided by previously added libraries. // -// The vars value may either be an `interpreter.Activation` instance or a `map[string]any`. +// The vars value may either be an `cel.Activation` instance or a `map[string]any`. func Globals(vars any) ProgramOption { return func(p *prog) (*prog, error) { - defaultVars, err := interpreter.NewActivation(vars) + defaultVars, err := NewActivation(vars) if err != nil { return nil, err } @@ -588,7 +588,7 @@ func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption { // // Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using // protocol buffers. 
-func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) { +func ContextProtoVars(ctx proto.Message) (Activation, error) { if ctx == nil || !ctx.ProtoReflect().IsValid() { return interpreter.EmptyActivation(), nil } @@ -612,7 +612,7 @@ func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) { } vars[field.TextName()] = fieldVal } - return interpreter.NewActivation(vars) + return NewActivation(vars) } // EnableMacroCallTracking ensures that call expressions which are replaced by macros diff --git a/cel/program.go b/cel/program.go index 49bd53783..c57932899 100644 --- a/cel/program.go +++ b/cel/program.go @@ -29,7 +29,7 @@ import ( type Program interface { // Eval returns the result of an evaluation of the Ast and environment against the input vars. // - // The vars value may either be an `interpreter.Activation` or a `map[string]any`. + // The vars value may either be an `Activation` or a `map[string]any`. // // If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will // be non-nil. Given this caveat on `details`, the return state from evaluation will be: @@ -47,14 +47,39 @@ type Program interface { // to support cancellation and timeouts. This method must be used in conjunction with the // InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation. // - // The vars value may either be an `interpreter.Activation` or `map[string]any`. + // The vars value may either be an `Activation` or `map[string]any`. // // The output contract for `ContextEval` is otherwise identical to the `Eval` method. ContextEval(context.Context, any) (ref.Val, *EvalDetails, error) } +// Activation used to resolve identifiers by name and references by id. +// +// An Activation is the primary mechanism by which a caller supplies input into a CEL program. 
+type Activation = interpreter.Activation + +// NewActivation returns an activation based on a map-based binding where the map keys are +// expected to be qualified names used with ResolveName calls. +// +// The input `bindings` may either be of type `Activation` or `map[string]any`. +// +// Lazy bindings may be supplied within the map-based input in either of the following forms: +// - func() any +// - func() ref.Val +// +// The output of the lazy binding will overwrite the variable reference in the internal map. +// +// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using +// the types.Adapter configured in the environment. +func NewActivation(bindings any) (Activation, error) { + return interpreter.NewActivation(bindings) +} + +// PartialActivation extends the Activation interface with a set of UnknownAttributePatterns. +type PartialActivation = interpreter.PartialActivation + // NoVars returns an empty Activation. -func NoVars() interpreter.Activation { +func NoVars() Activation { return interpreter.EmptyActivation() } @@ -64,10 +89,9 @@ func NoVars() interpreter.Activation { // This method relies on manually configured sets of missing attribute patterns. For a method which // infers the missing variables from the input and the configured environment, use Env.PartialVars(). // -// The `vars` value may either be an interpreter.Activation or any valid input to the -// interpreter.NewActivation call. +// The `vars` value may either be an Activation or any valid input to the NewActivation call. func PartialVars(vars any, - unknowns ...*interpreter.AttributePattern) (interpreter.PartialActivation, error) { + unknowns ...*interpreter.AttributePattern) (PartialActivation, error) { return interpreter.NewPartialActivation(vars, unknowns...) 
} @@ -120,7 +144,7 @@ func (ed *EvalDetails) ActualCost() *uint64 { type prog struct { *Env evalOpts EvalOption - defaultVars interpreter.Activation + defaultVars Activation dispatcher interpreter.Dispatcher interpreter interpreter.Interpreter interruptCheckFrequency uint @@ -285,9 +309,9 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) { } }() // Build a hierarchical activation if there are default vars set. - var vars interpreter.Activation + var vars Activation switch v := input.(type) { - case interpreter.Activation: + case Activation: vars = v case map[string]any: vars = activationPool.Setup(v) @@ -315,9 +339,9 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail } // Configure the input, making sure to wrap Activation inputs in the special ctxActivation which // exposes the #interrupted variable and manages rate-limited checks of the ctx.Done() state. - var vars interpreter.Activation + var vars Activation switch v := input.(type) { - case interpreter.Activation: + case Activation: vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency) defer ctxActivationPool.Put(vars) case map[string]any: @@ -414,7 +438,7 @@ func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalD } type ctxEvalActivation struct { - parent interpreter.Activation + parent Activation interrupt <-chan struct{} interruptCheckCount uint interruptCheckFrequency uint @@ -438,7 +462,7 @@ func (a *ctxEvalActivation) ResolveName(name string) (any, bool) { return a.parent.ResolveName(name) } -func (a *ctxEvalActivation) Parent() interpreter.Activation { +func (a *ctxEvalActivation) Parent() Activation { return a.parent } @@ -457,7 +481,7 @@ type ctxEvalActivationPool struct { } // Setup initializes a pooled Activation with the ability check for context.Context cancellation -func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation 
{ +func (p *ctxEvalActivationPool) Setup(vars Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation { a := p.Pool.Get().(*ctxEvalActivation) a.parent = vars a.interrupt = done @@ -506,8 +530,8 @@ func (a *evalActivation) ResolveName(name string) (any, bool) { } } -// Parent implements the interpreter.Activation interface -func (a *evalActivation) Parent() interpreter.Activation { +// Parent implements the Activation interface +func (a *evalActivation) Parent() Activation { return nil } diff --git a/ext/bindings.go b/ext/bindings.go index 50cf4fb3d..81dae50f2 100644 --- a/ext/bindings.go +++ b/ext/bindings.go @@ -224,7 +224,7 @@ func (b *dynamicBlock) ID() int64 { } // Eval implements the Interpretable interface method. -func (b *dynamicBlock) Eval(activation interpreter.Activation) ref.Val { +func (b *dynamicBlock) Eval(activation cel.Activation) ref.Val { sa := b.slotActivationPool.Get().(*dynamicSlotActivation) sa.Activation = activation defer b.clearSlots(sa) @@ -242,7 +242,7 @@ type slotVal struct { } type dynamicSlotActivation struct { - interpreter.Activation + cel.Activation slotExprs []interpreter.Interpretable slotCount int slotVals []*slotVal @@ -295,13 +295,13 @@ func (b *constantBlock) ID() int64 { // Eval implements the interpreter.Interpretable interface method, and will proxy @index prefixed variable // lookups into a set of constant slots determined from the plan step. 
-func (b *constantBlock) Eval(activation interpreter.Activation) ref.Val { +func (b *constantBlock) Eval(activation cel.Activation) ref.Val { vars := constantSlotActivation{Activation: activation, slots: b.slots, slotCount: b.slotCount} return b.expr.Eval(vars) } type constantSlotActivation struct { - interpreter.Activation + cel.Activation slots traits.Lister slotCount int } diff --git a/repl/BUILD.bazel b/repl/BUILD.bazel index fa6ca2a38..4d4b41b6a 100644 --- a/repl/BUILD.bazel +++ b/repl/BUILD.bazel @@ -34,7 +34,6 @@ go_library( "//common/types:go_default_library", "//common/types/ref:go_default_library", "//ext:go_default_library", - "//interpreter:go_default_library", "//repl/parser:go_default_library", "@com_github_antlr4_go_antlr_v4//:go_default_library", "@dev_cel_expr//conformance/proto2:go_default_library", @@ -46,7 +45,7 @@ go_library( "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", "@org_golang_google_protobuf//reflect/protodesc:go_default_library", "@org_golang_google_protobuf//types/descriptorpb:go_default_library", - ], + ], ) go_test( diff --git a/repl/evaluator.go b/repl/evaluator.go index dcd341621..c26b09c62 100644 --- a/repl/evaluator.go +++ b/repl/evaluator.go @@ -28,7 +28,6 @@ import ( "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/ext" - "github.com/google/cel-go/interpreter" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" @@ -666,7 +665,7 @@ func (e *Evaluator) Status() string { // applyContext evaluates the let expressions in the context to build an activation for the given expression. // returns the environment for compiling and planning the top level CEL expression and an activation with the // values of the let expressions. 
-func (e *Evaluator) applyContext() (*cel.Env, interpreter.Activation, error) { +func (e *Evaluator) applyContext() (*cel.Env, cel.Activation, error) { var vars = make(map[string]any) for _, el := range e.ctx.letVars { @@ -683,7 +682,7 @@ func (e *Evaluator) applyContext() (*cel.Env, interpreter.Activation, error) { } } - act, err := interpreter.NewActivation(vars) + act, err := cel.NewActivation(vars) if err != nil { return nil, nil, err } diff --git a/test/bench/bench.go b/test/bench/bench.go index 10b14dac6..6725faaf7 100644 --- a/test/bench/bench.go +++ b/test/bench/bench.go @@ -33,7 +33,7 @@ type Case struct { // Options indicate additional pieces of configuration such as CEL libraries, variables, and functions. Options []cel.EnvOption - // In is expected to be a map[string]any or interpreter.Activation instance representing the input to the expression. + // In is expected to be a map[string]any or cel.Activation instance representing the input to the expression. In any // Out is the expected CEL valued output. @@ -48,7 +48,7 @@ type DynamicEnvCase struct { // Options indicate additional pieces of configuration such as CEL libraries, variables, and functions. Options func(b *testing.B) *cel.Env - // In is expected to be a map[string]any or interpreter.Activation instance representing the input to the expression. + // In is expected to be a map[string]any or cel.Activation instance representing the input to the expression. In any // Out is the expected CEL valued output. 
From 9a4b48b7ba3ccb20ffe70d24d30aef8d8fe8db49 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 10 Feb 2025 09:28:38 -0800 Subject: [PATCH 07/46] ContextEval support for Unknowns (#1126) * Fix for ContextEval with unknown attributes --- cel/cel_test.go | 42 +++++++++++++++++++++++++++++++ cel/program.go | 5 ++++ interpreter/activation.go | 13 +++++----- interpreter/attribute_patterns.go | 2 +- interpreter/interpretable.go | 6 ++--- 5 files changed, 58 insertions(+), 10 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index f1c4385d6..e48b38b82 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1005,6 +1005,48 @@ func TestContextEval(t *testing.T) { } } +func TestContextEvalUnknowns(t *testing.T) { + env, err := NewEnv( + Variable("groups", ListType(IntType)), + Variable("id", IntType), + ) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + + pvars, err := PartialVars( + map[string]any{ + "groups": []int{1, 2, 3}, + }, + AttributePattern("id"), + ) + if err != nil { + t.Fatalf("PartialVars() failed: %v", err) + } + + ast, iss := env.Compile(`groups.exists(t, t == id)`) + if iss.Err() != nil { + t.Fatalf("env.Compile() failed: %v", iss.Err()) + } + + prg, err := env.Program(ast, EvalOptions(OptTrackState, OptPartialEval), InterruptCheckFrequency(100)) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + + out, _, err := prg.Eval(pvars) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + ctxOut, _, err := prg.ContextEval(context.Background(), pvars) + if err != nil { + t.Fatalf("prg.ContextEval() failed: %v", err) + } + if !reflect.DeepEqual(out, ctxOut) { + t.Errorf("got %v, wanted %v", out, ctxOut) + } +} + func BenchmarkContextEval(b *testing.B) { env := testEnv(b, Variable("items", ListType(IntType)), diff --git a/cel/program.go b/cel/program.go index c57932899..144d1f25a 100644 --- a/cel/program.go +++ b/cel/program.go @@ -466,6 +466,11 @@ func (a *ctxEvalActivation) Parent() Activation { return 
a.parent } +func (a *ctxEvalActivation) AsPartialActivation() (interpreter.PartialActivation, bool) { + pa, ok := a.parent.(interpreter.PartialActivation) + return pa, ok +} + func newCtxEvalActivationPool() *ctxEvalActivationPool { return &ctxEvalActivationPool{ Pool: sync.Pool{ diff --git a/interpreter/activation.go b/interpreter/activation.go index c20d19de1..80105f4ff 100644 --- a/interpreter/activation.go +++ b/interpreter/activation.go @@ -158,7 +158,8 @@ type PartialActivation interface { // partialActivationConverter indicates whether an Activation implementation supports conversion to a PartialActivation type partialActivationConverter interface { - asPartialActivation() (PartialActivation, bool) + // AsPartialActivation converts the current activation to a PartialActivation + AsPartialActivation() (PartialActivation, bool) } // partActivation is the default implementations of the PartialActivation interface. @@ -172,19 +173,19 @@ func (a *partActivation) UnknownAttributePatterns() []*AttributePattern { return a.unknowns } -// asPartialActivation returns the partActivation as a PartialActivation interface. -func (a *partActivation) asPartialActivation() (PartialActivation, bool) { +// AsPartialActivation returns the partActivation as a PartialActivation interface. 
+func (a *partActivation) AsPartialActivation() (PartialActivation, bool) { return a, true } -func asPartialActivation(vars Activation) (PartialActivation, bool) { +func AsPartialActivation(vars Activation) (PartialActivation, bool) { // Only internal activation instances may implement this interface if pv, ok := vars.(partialActivationConverter); ok { - return pv.asPartialActivation() + return pv.AsPartialActivation() } // Since Activations may be hierarchical, test whether a parent converts to a PartialActivation if vars.Parent() != nil { - return asPartialActivation(vars.Parent()) + return AsPartialActivation(vars.Parent()) } return nil, false } diff --git a/interpreter/attribute_patterns.go b/interpreter/attribute_patterns.go index 7e5c2db0f..7d0759e37 100644 --- a/interpreter/attribute_patterns.go +++ b/interpreter/attribute_patterns.go @@ -358,7 +358,7 @@ func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) { func (m *attributeMatcher) Resolve(vars Activation) (any, error) { id := m.NamespacedAttribute.ID() // Bug in how partial activation is resolved, should search parents as well. - partial, isPartial := asPartialActivation(vars) + partial, isPartial := AsPartialActivation(vars) if isPartial { unk, err := m.fac.matchesUnknownPatterns( partial, diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 591b7688b..1573523be 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -1370,16 +1370,16 @@ func (f *folder) Parent() Activation { // if they were provided to the input activation, or an empty set if the proxied activation is not partial. 
func (f *folder) UnknownAttributePatterns() []*AttributePattern { if pv, ok := f.activation.(partialActivationConverter); ok { - if partial, isPartial := pv.asPartialActivation(); isPartial { + if partial, isPartial := pv.AsPartialActivation(); isPartial { return partial.UnknownAttributePatterns() } } return []*AttributePattern{} } -func (f *folder) asPartialActivation() (PartialActivation, bool) { +func (f *folder) AsPartialActivation() (PartialActivation, bool) { if pv, ok := f.activation.(partialActivationConverter); ok { - if _, isPartial := pv.asPartialActivation(); isPartial { + if _, isPartial := pv.AsPartialActivation(); isPartial { return f, true } } From b7c14faa5d55dc6e7a7890be5a9cdddfd13eedf1 Mon Sep 17 00:00:00 2001 From: origolucis Date: Tue, 11 Feb 2025 02:28:58 +0900 Subject: [PATCH 08/46] Fix godoc formatting (#1127) --- cel/library.go | 7 +++---- ext/lists.go | 3 +-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/cel/library.go b/cel/library.go index c0aef5019..148516c37 100644 --- a/cel/library.go +++ b/cel/library.go @@ -263,7 +263,7 @@ func (stdLibrary) ProgramOptions() []ProgramOption { // be expressed with `optMap`. // // msg.?elements.optFlatMap(e, e[?0]) // return the first element if present. - +// // # First // // Introduced in version: 2 @@ -272,7 +272,7 @@ func (stdLibrary) ProgramOptions() []ProgramOption { // optional.None. // // [1, 2, 3].first().value() == 1 - +// // # Last // // Introduced in version: 2 @@ -283,7 +283,7 @@ func (stdLibrary) ProgramOptions() []ProgramOption { // [1, 2, 3].last().value() == 3 // // This is syntactic sugar for msg.elements[msg.elements.size()-1]. 
- +// // # Unwrap / UnwrapOpt // // Introduced in version: 2 @@ -293,7 +293,6 @@ func (stdLibrary) ProgramOptions() []ProgramOption { // // optional.unwrap([optional.of(42), optional.none()]) == [42] // [optional.of(42), optional.none()].unwrapOpt() == [42] - func OptionalTypes(opts ...OptionalTypesOption) EnvOption { lib := &optionalLib{version: math.MaxUint32} for _, opt := range opts { diff --git a/ext/lists.go b/ext/lists.go index 675ea8672..9a3cce37e 100644 --- a/ext/lists.go +++ b/ext/lists.go @@ -134,7 +134,7 @@ var comparableTypes = []*cel.Type{ // // .sortBy(, ) -> // keyExpr returns a value in {int, uint, double, bool, duration, timestamp, string, bytes} - +// // Examples: // // [ @@ -143,7 +143,6 @@ var comparableTypes = []*cel.Type{ // Player { name: "baz", score: 1000 }, // ].sortBy(e, e.score).map(e, e.name) // == ["bar", "foo", "baz"] - func Lists(options ...ListsOption) cel.EnvOption { l := &listsLib{version: math.MaxUint32} for _, o := range options { From fddae56038f97253b97c971deecd76cec39fb89e Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 10 Feb 2025 17:48:57 -0800 Subject: [PATCH 09/46] Canonical environment description and stdlib subsetting (#1125) * Canonical CEL environment with stdlib subsetting support --- cel/BUILD.bazel | 1 + cel/cel_test.go | 84 +++ cel/decls.go | 12 +- cel/library.go | 57 +- checker/checker_test.go | 3 - checker/env_test.go | 20 +- common/ast/navigable_test.go | 4 - common/decls/decls.go | 4 + common/env/BUILD.bazel | 18 + common/env/env.go | 393 ++++++++++ common/env/env_test.go | 701 ++++++++++++++++++ common/types/types.go | 4 +- common/types/types_test.go | 4 +- go.sum | 4 - interpreter/interpreter_test.go | 4 - policy/BUILD.bazel | 1 + policy/compiler_test.go | 8 +- policy/config.go | 258 ++----- policy/config_test.go | 64 +- policy/go.mod | 2 +- policy/go.sum | 4 +- policy/helper_test.go | 5 +- policy/parser.go | 2 +- policy/testdata/context_pb/config.yaml | 4 +- policy/testdata/k8s/config.yaml | 19 
+- policy/testdata/nested_rule/config.yaml | 9 +- policy/testdata/pb/config.yaml | 3 +- policy/testdata/required_labels/config.yaml | 18 +- .../restricted_destinations/base_config.yaml | 34 +- .../restricted_destinations/config.yaml | 34 +- 30 files changed, 1398 insertions(+), 380 deletions(-) create mode 100644 common/env/BUILD.bazel create mode 100644 common/env/env.go create mode 100644 common/env/env_test.go diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index 81549fb4c..d89595821 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -29,6 +29,7 @@ go_library( "//common/ast:go_default_library", "//common/containers:go_default_library", "//common/decls:go_default_library", + "//common/env:go_default_library", "//common/functions:go_default_library", "//common/operators:go_default_library", "//common/overloads:go_default_library", diff --git a/cel/cel_test.go b/cel/cel_test.go index e48b38b82..e85c5786f 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -31,6 +31,7 @@ import ( "github.com/google/cel-go/checker" celast "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" @@ -359,6 +360,89 @@ func TestExtendStdlibFunction(t *testing.T) { } } +func TestSubsetStdLib(t *testing.T) { + env, err := NewCustomEnv(StdLib(StdLibSubset( + &env.LibrarySubset{ + IncludeMacros: []string{"has"}, + IncludeFunctions: []*env.Function{ + {Name: operators.Equals}, + {Name: operators.NotEquals}, + {Name: operators.LogicalAnd}, + {Name: operators.LogicalOr}, + {Name: operators.LogicalNot}, + {Name: overloads.Size, Overloads: []*env.Overload{{ID: "list_size"}}}, + }, + }, + ))) + if err != nil { + t.Fatalf("StdLib() subsetting failed: %v", err) + } + tests := []struct { + name string + expr string + compiles bool + want ref.Val + }{ + { + name: "has macro", + expr: "!has({}.a)", + compiles: true, + want: types.True, + }, + { + name: 
"not equals", + expr: "has({}.a) != true", + compiles: true, + want: types.True, + }, + { + name: "logical operators", + expr: "has({}.a) != true && has({'b': 1}.b) == true", + compiles: true, + want: types.True, + }, + { + name: "list size - allowed", + expr: "[1, 2, 3].size()", + compiles: true, + want: types.Int(3), + }, + { + name: "excluded macro", + expr: "[1, 2, 3].exists(i, i != 0)", + compiles: false, + }, + { + name: "string size - not allowed", + expr: "'hello'.size()", + compiles: false, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + ast, iss := env.Compile(tc.expr) + if tc.compiles && iss.Err() != nil { + t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err()) + } + if !tc.compiles && iss.Err() != nil { + return + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + t.Fatalf("prg.Eval() failed: %s", err) + } + if out.Equal(tc.want) != types.True { + t.Errorf("prg.Eval() got %v, wanted %v", out, tc.want) + } + }) + } +} + func TestCustomTypes(t *testing.T) { reg := types.NewEmptyRegistry() env := testEnv(t, diff --git a/cel/decls.go b/cel/decls.go index 7a2bd9b7c..eedc909bb 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -148,6 +148,16 @@ func Variable(name string, t *Type) EnvOption { } } +// VariableDecls configures a set of fully defined cel.VariableDecl instances in the environment. +func VariableDecls(vars ...*decls.VariableDecl) EnvOption { + return func(e *Env) (*Env, error) { + for _, v := range vars { + e.variables = append(e.variables, v) + } + return e, nil + } +} + // Function defines a function and overloads with optional singleton or per-overload bindings. // // Using Function is roughly equivalent to calling Declarations() to declare the function signatures @@ -209,7 +219,7 @@ func ExcludeOverloads(overloadIDs ...string) OverloadSelector { return decls.ExcludeOverloads(overloadIDs...) 
} -// FunctionDecls provides one or more fully formed function declaration to be added to the environment. +// FunctionDecls provides one or more fully formed function declarations to be added to the environment. func FunctionDecls(funcs ...*decls.FunctionDecl) EnvOption { return func(e *Env) (*Env, error) { var err error diff --git a/cel/library.go b/cel/library.go index 148516c37..1d081852c 100644 --- a/cel/library.go +++ b/cel/library.go @@ -22,6 +22,8 @@ import ( "time" "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/stdlib" @@ -94,26 +96,61 @@ func Lib(l Library) EnvOption { } } +// StdLibOption specifies a functional option for configuring the standard CEL library. +type StdLibOption func(*stdLibrary) *stdLibrary + +// StdLibSubset configures the standard library to use a subset of its functions and macros. +func StdLibSubset(subset *env.LibrarySubset) StdLibOption { + return func(lib *stdLibrary) *stdLibrary { + lib.subset = subset + return lib + } +} + // StdLib returns an EnvOption for the standard library of CEL functions and macros. -func StdLib() EnvOption { - return Lib(stdLibrary{}) +func StdLib(opts ...StdLibOption) EnvOption { + lib := &stdLibrary{} + for _, o := range opts { + lib = o(lib) + } + return Lib(lib) } // stdLibrary implements the Library interface and provides functional options for the core CEL // features documented in the specification. -type stdLibrary struct{} +type stdLibrary struct { + subset *env.LibrarySubset +} // LibraryName implements the SingletonLibrary interface method. -func (stdLibrary) LibraryName() string { +func (*stdLibrary) LibraryName() string { return "cel.lib.std" } // CompileOptions returns options for the standard CEL function declarations and macros. 
-func (stdLibrary) CompileOptions() []EnvOption { +func (lib *stdLibrary) CompileOptions() []EnvOption { + funcs := stdlib.Functions() + macros := StandardMacros + if lib.subset != nil { + subMacros := []Macro{} + for _, m := range macros { + if lib.subset.SubsetMacro(m.Function()) { + subMacros = append(subMacros, m) + } + } + macros = subMacros + subFuncs := []*decls.FunctionDecl{} + for _, fn := range funcs { + if f, include := lib.subset.SubsetFunction(fn); include { + subFuncs = append(subFuncs, f) + } + } + funcs = subFuncs + } return []EnvOption{ func(e *Env) (*Env, error) { var err error - for _, fn := range stdlib.Functions() { + for _, fn := range funcs { existing, found := e.functions[fn.Name()] if found { fn, err = existing.Merge(fn) @@ -125,16 +162,12 @@ func (stdLibrary) CompileOptions() []EnvOption { } return e, nil }, - func(e *Env) (*Env, error) { - e.variables = append(e.variables, stdlib.Types()...) - return e, nil - }, - Macros(StandardMacros...), + Macros(macros...), } } // ProgramOptions returns function implementations for the standard CEL functions. -func (stdLibrary) ProgramOptions() []ProgramOption { +func (*stdLibrary) ProgramOptions() []ProgramOption { return []ProgramOption{} } diff --git a/checker/checker_test.go b/checker/checker_test.go index c689c52dc..23b17f3ab 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -2377,7 +2377,6 @@ func TestCheck(t *testing.T) { t.Fatalf("NewEnv(cont, reg) failed: %v", err) } if !tc.disableStdEnv { - env.AddIdents(stdlib.Types()...) env.AddFunctions(stdlib.Functions()...) } if tc.env.idents != nil { @@ -2467,7 +2466,6 @@ func BenchmarkCheck(b *testing.B) { b.Fatalf("NewEnv(cont, reg) failed: %v", err) } if !tc.disableStdEnv { - env.AddIdents(stdlib.Types()...) env.AddFunctions(stdlib.Functions()...) 
} if tc.env.idents != nil { @@ -2583,7 +2581,6 @@ func TestCheckErrorData(t *testing.T) { if err != nil { t.Fatalf("NewEnv(cont, reg) failed: %v", err) } - env.AddIdents(stdlib.Types()...) env.AddFunctions(stdlib.Functions()...) _, iss = Check(ast, src, env) if len(iss.GetErrors()) != 1 { diff --git a/checker/env_test.go b/checker/env_test.go index 44d1894be..2ec7f13fa 100644 --- a/checker/env_test.go +++ b/checker/env_test.go @@ -26,16 +26,6 @@ import ( "github.com/google/cel-go/parser" ) -func TestOverlappingIdentifier(t *testing.T) { - env := newStdEnv(t) - err := env.AddIdents(decls.NewVariable("int", types.TypeType)) - if err == nil { - t.Error("Got nil, wanted error") - } else if !strings.Contains(err.Error(), "overlapping identifier") { - t.Errorf("Got %v, wanted overlapping identifier error", err) - } -} - func TestOverlappingMacro(t *testing.T) { env := newStdEnv(t) hasFn, err := decls.NewFunction("has", @@ -92,10 +82,6 @@ func BenchmarkCopyDeclarations(b *testing.B) { if err != nil { b.Fatalf("NewEnv() failed: %v", err) } - err = env.AddIdents(stdlib.Types()...) - if err != nil { - b.Fatalf("env.AddIdents(stdlib.Types()...) failed: %v", err) - } err = env.AddFunctions(stdlib.Functions()...) if err != nil { b.Fatalf("env.AddFunctions(stdlib.Functions()...) failed: %v", err) @@ -111,13 +97,9 @@ func newStdEnv(t *testing.T) *Env { if err != nil { t.Fatalf("NewEnv() failed: %v", err) } - err = env.AddIdents(stdlib.Types()...) - if err != nil { - t.Fatalf("env.Add(stdlib.TypeExprDecls()...) failed: %v", err) - } err = env.AddFunctions(stdlib.Functions()...) if err != nil { - t.Fatalf("env.Add(stdlib.FunctionExprDecls()...) failed: %v", err) + t.Fatalf("env.Add(stdlib.Functions()...) 
failed: %v", err) } return env } diff --git a/common/ast/navigable_test.go b/common/ast/navigable_test.go index 5dab4331a..5afb23ab7 100644 --- a/common/ast/navigable_test.go +++ b/common/ast/navigable_test.go @@ -594,10 +594,6 @@ func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry) * if err != nil { t.Fatalf("checker.NewEnv(%v, %v) failed: %v", cont, reg, err) } - err = env.AddIdents(stdlib.Types()...) - if err != nil { - t.Fatalf("env.Add(stdlib.Types()...) failed: %v", err) - } err = env.AddFunctions(stdlib.Functions()...) if err != nil { t.Fatalf("env.Add(stdlib.Functions()...) failed: %v", err) diff --git a/common/decls/decls.go b/common/decls/decls.go index 451a6c0d6..df05d2198 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -148,6 +148,10 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) { return merged, nil } +// FunctionSubsetter subsets a function declaration or returns nil and false if the function +// subset was empty. +type FunctionSubsetter func(fn *FunctionDecl) (*FunctionDecl, bool) + // OverloadSelector selects an overload associated with a given function when it returns true. // // Used in combination with the Subset method. 
diff --git a/common/env/BUILD.bazel b/common/env/BUILD.bazel new file mode 100644 index 000000000..148d49c14 --- /dev/null +++ b/common/env/BUILD.bazel @@ -0,0 +1,18 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +go_library( + name = "go_default_library", + srcs = [ + "env.go", + ], + importpath = "github.com/google/cel-go/common/env", + deps = [ + "//common/decls:go_default_library", + "//common/types:go_default_library", + ], +) diff --git a/common/env/env.go b/common/env/env.go new file mode 100644 index 000000000..3595b388c --- /dev/null +++ b/common/env/env.go @@ -0,0 +1,393 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env provides a representation of a CEL environment. +package env + +import ( + "errors" + "fmt" + "math" + "strconv" + + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/types" +) + +// NewConfig creates an instance of a YAML serializable CEL environment configuration. +func NewConfig() *Config { + return &Config{ + Imports: []*Import{}, + Extensions: []*Extension{}, + Variables: []*Variable{}, + Functions: []*Function{}, + } +} + +// Config represents a serializable form of the CEL environment configuration. 
+// +// Note: custom validations, feature flags, and performance tuning parameters are +// not (yet) considered part of the core CEL environment configuration and should +// be managed separately until a common convention for such configuration settings +// can be developed. +type Config struct { + Name string `yaml:"name"` + Description string `yaml:"description,omitempty"` + Container string `yaml:"container,omitempty"` + Imports []*Import `yaml:"imports,omitempty"` + StdLib *LibrarySubset `yaml:"stdlib,omitempty"` + Extensions []*Extension `yaml:"extensions,omitempty"` + ContextVariable *ContextVariable `yaml:"context_variable,omitempty"` + Variables []*Variable `yaml:"variables,omitempty"` + Functions []*Function `yaml:"functions,omitempty"` +} + +// Import represents a type name that will be appreviated by its simple name using +// the cel.Abbrevs() option. +type Import struct { + Name string `yaml:"name"` +} + +// Variable represents a typed variable declaration which will be published via the +// cel.VariableDecls() option. +type Variable struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + + // Type represents the type declaration for the variable. + // + // Deprecated: use the embedded *TypeDesc fields directly. + Type *TypeDesc `yaml:"type,omitempty"` + + // TypeDesc is an embedded set of fields allowing for the specification of the Variable type. + *TypeDesc `yaml:",inline"` +} + +// GetType returns the variable type description. +// +// Note, if both the embedded TypeDesc and the field Type are non-nil, the embedded TypeDesc will +// take precedence. +func (vd *Variable) GetType() *TypeDesc { + if vd == nil { + return nil + } + if vd.TypeDesc != nil { + return vd.TypeDesc + } + if vd.Type != nil { + return vd.Type + } + return nil +} + +// AsCELVariable converts the serializable form of the Variable into a CEL environment declaration. 
+func (vd *Variable) AsCELVariable(tp types.Provider) (*decls.VariableDecl, error) { + if vd == nil { + return nil, errors.New("nil Variable cannot be converted to a VariableDecl") + } + if vd.Name == "" { + return nil, errors.New("invalid variable, must declare a name") + } + if vd.GetType() != nil { + t, err := vd.GetType().AsCELType(tp) + if err != nil { + return nil, fmt.Errorf("invalid variable type for '%s': %w", vd.Name, err) + } + return decls.NewVariable(vd.Name, t), nil + } + return nil, fmt.Errorf("invalid variable '%s', no type specified", vd.Name) +} + +// ContextVariable represents a structured message whose fields are to be treated as the top-level +// variable identifiers within CEL expressions. +type ContextVariable struct { + // TypeName represents the fully qualified typename of the context variable. + // Currently, only protobuf types are supported. + TypeName string `yaml:"type_name"` +} + +// Function represents the serializable format of a function and its overloads. +type Function struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Overloads []*Overload `yaml:"overloads"` +} + +// AsCELFunction converts the serializable form of the Function into CEL environment declaration. +func (fn *Function) AsCELFunction(tp types.Provider) (*decls.FunctionDecl, error) { + if fn == nil { + return nil, errors.New("nil Function cannot be converted to a FunctionDecl") + } + if fn.Name == "" { + return nil, errors.New("invalid function, must declare a name") + } + if len(fn.Overloads) == 0 { + return nil, fmt.Errorf("invalid function %s, must declare an overload", fn.Name) + } + overloads := make([]decls.FunctionOpt, len(fn.Overloads)) + var err error + for i, o := range fn.Overloads { + overloads[i], err = o.AsFunctionOption(tp) + if err != nil { + return nil, err + } + } + return decls.NewFunction(fn.Name, overloads...) +} + +// Overload represents the serializable format of a function overload. 
+type Overload struct { + ID string `yaml:"id"` + Description string `yaml:"description"` + Target *TypeDesc `yaml:"target"` + Args []*TypeDesc `yaml:"args"` + Return *TypeDesc `yaml:"return"` +} + +// AsFunctionOption converts the serializable form of the Overload into a function declaration option. +func (od *Overload) AsFunctionOption(tp types.Provider) (decls.FunctionOpt, error) { + if od == nil { + return nil, errors.New("nil Overload cannot be converted to a FunctionOpt") + } + args := make([]*types.Type, len(od.Args)) + var err error + for i, a := range od.Args { + args[i], err = a.AsCELType(tp) + if err != nil { + return nil, err + } + } + if od.Return == nil { + return nil, fmt.Errorf("missing return type on overload: %v", od.ID) + } + result, err := od.Return.AsCELType(tp) + if err != nil { + return nil, err + } + if od.Target != nil { + t, err := od.Target.AsCELType(tp) + if err != nil { + return nil, err + } + args = append([]*types.Type{t}, args...) + return decls.MemberOverload(od.ID, args, result), nil + } + return decls.Overload(od.ID, args, result), nil +} + +// Extension represents a named and optionally versioned extension library configured in the environment. +type Extension struct { + // Name is either the LibraryName() or some short-hand simple identifier which is understood by the config-handler. + Name string `yaml:"name"` + + // Version may either be an unsigned long value or the string 'latest'. If empty, the value is treated as '0'. + Version string `yaml:"version,omitempty"` +} + +// GetVersion returns the parsed version string, or an error if the version cannot be parsed. 
+func (e *Extension) GetVersion() (uint32, error) {
+	if e == nil {
+		return 0, errors.New("nil Extension cannot produce a version")
+	}
+	if e.Version == "latest" {
+		return math.MaxUint32, nil
+	}
+	if e.Version == "" {
+		return uint32(0), nil
+	}
+	ver, err := strconv.ParseUint(e.Version, 10, 32)
+	if err != nil {
+		return 0, fmt.Errorf("error parsing uint version: %w", err)
+	}
+	return uint32(ver), nil
+}
+
+// LibrarySubset indicates a subset of the macros and functions supported by a subsettable library.
+type LibrarySubset struct {
+	// DisableMacros disables macros for the given library.
+	DisableMacros bool `yaml:"disable_macros"`
+
+	// IncludeMacros specifies a set of macro function names to include in the subset.
+	IncludeMacros []string `yaml:"include_macros"`
+
+	// ExcludeMacros specifies a set of macro function names to exclude from the subset.
+	// Note: if IncludeMacros is non-empty, then ExcludeMacros is ignored.
+	ExcludeMacros []string `yaml:"exclude_macros"`
+
+	// IncludeFunctions specifies a set of functions to include in the subset.
+	//
+	// Note: the overloads specified in the subset need only specify their ID.
+	// Note: if IncludeFunctions is non-empty, then ExcludeFunctions is ignored.
+	IncludeFunctions []*Function `yaml:"include_functions"`
+
+	// ExcludeFunctions specifies the set of functions to exclude from the subset.
+	//
+	// Note: the overloads specified in the subset need only specify their ID.
+	ExcludeFunctions []*Function `yaml:"exclude_functions"`
+}
+
+// SubsetFunction produces a function declaration which matches the supported subset, or nil
+// if the function is not supported by the LibrarySubset.
+//
+// For IncludeFunctions, if the function does not specify a set of overloads to include, the
+// whole function definition is included. If overloads are set, then a new function which
+// includes only the specified overloads is produced.
+// +// For ExcludeFunctions, if the function does not specify a set of overloads to exclude, the +// whole function definition is excluded. If overloads are set, then a new function which +// includes only the permitted overloads is produced. +func (lib *LibrarySubset) SubsetFunction(fn *decls.FunctionDecl) (*decls.FunctionDecl, bool) { + // When lib is null, it should indicate that all values are included in the subset. + if lib == nil { + return fn, true + } + if len(lib.IncludeFunctions) != 0 { + for _, include := range lib.IncludeFunctions { + if include.Name != fn.Name() { + continue + } + if len(include.Overloads) == 0 { + return fn, true + } + overloadIDs := make([]string, len(include.Overloads)) + for i, o := range include.Overloads { + overloadIDs[i] = o.ID + } + return fn.Subset(decls.IncludeOverloads(overloadIDs...)), true + } + return nil, false + } + if len(lib.ExcludeFunctions) != 0 { + for _, exclude := range lib.ExcludeFunctions { + if exclude.Name != fn.Name() { + continue + } + if len(exclude.Overloads) == 0 { + return nil, false + } + overloadIDs := make([]string, len(exclude.Overloads)) + for i, o := range exclude.Overloads { + overloadIDs[i] = o.ID + } + return fn.Subset(decls.ExcludeOverloads(overloadIDs...)), true + } + return fn, true + } + return fn, true +} + +// SubsetMacro indicates whether the macro function should be included in the library subset. +func (lib *LibrarySubset) SubsetMacro(macroFunction string) bool { + // When lib is null, it should indicate that all values are included in the subset. 
+ if lib == nil { + return true + } + if lib.DisableMacros { + return false + } + if len(lib.IncludeMacros) != 0 { + for _, name := range lib.IncludeMacros { + if name == macroFunction { + return true + } + } + return false + } + if len(lib.ExcludeMacros) != 0 { + for _, name := range lib.ExcludeMacros { + if name == macroFunction { + return false + } + } + return true + } + return true +} + +// TypeDesc represents the serializable format of a CEL *types.Type value. +type TypeDesc struct { + TypeName string `yaml:"type_name"` + Params []*TypeDesc `yaml:"params"` + IsTypeParam bool `yaml:"is_type_param"` +} + +// AsCELType converts the serializable object to a *types.Type value. +func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) { + if td == nil { + return nil, errors.New("nil TypeDesc cannot be converted to a Type instance") + } + if td.TypeName == "" { + return nil, errors.New("invalid type description, declare a type name") + } + var err error + switch td.TypeName { + case "dyn": + return types.DynType, nil + case "map": + if len(td.Params) == 2 { + kt, err := td.Params[0].AsCELType(tp) + if err != nil { + return nil, err + } + vt, err := td.Params[1].AsCELType(tp) + if err != nil { + return nil, err + } + return types.NewMapType(kt, vt), nil + } + return nil, fmt.Errorf("map type has unexpected param count: %d", len(td.Params)) + case "list": + if len(td.Params) == 1 { + et, err := td.Params[0].AsCELType(tp) + if err != nil { + return nil, err + } + return types.NewListType(et), nil + } + return nil, fmt.Errorf("list type has unexpected param count: %d", len(td.Params)) + case "optional_type": + if len(td.Params) == 1 { + et, err := td.Params[0].AsCELType(tp) + if err != nil { + return nil, err + } + return types.NewOptionalType(et), nil + } + return nil, fmt.Errorf("optional_type has unexpected param count: %d", len(td.Params)) + default: + if td.IsTypeParam { + return types.NewTypeParamType(td.TypeName), nil + } + if msgType, found := 
tp.FindStructType(td.TypeName); found { + // First parameter is the type name. + return msgType.Parameters()[0], nil + } + t, found := tp.FindIdent(td.TypeName) + if !found { + return nil, fmt.Errorf("undefined type name: %v", td.TypeName) + } + _, ok := t.(*types.Type) + if ok && len(td.Params) == 0 { + return t.(*types.Type), nil + } + params := make([]*types.Type, len(td.Params)) + for i, p := range td.Params { + params[i], err = p.AsCELType(tp) + if err != nil { + return nil, err + } + } + return types.NewOpaqueType(td.TypeName, params...), nil + } +} diff --git a/common/env/env_test.go b/common/env/env_test.go new file mode 100644 index 000000000..a59347aaa --- /dev/null +++ b/common/env/env_test.go @@ -0,0 +1,701 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package env + +import ( + "errors" + "math" + "reflect" + "strings" + "testing" + + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/types" +) + +func TestConfig(t *testing.T) { + conf := NewConfig() + if conf == nil { + t.Fatal("got nil config, wanted non-nil value") + } +} + +func TestVariableGetType(t *testing.T) { + tests := []struct { + name string + v *Variable + t *TypeDesc + }{ + { + name: "nil-safety check", + v: nil, + t: nil, + }, + { + name: "nil type access", + v: &Variable{}, + t: nil, + }, + { + name: "nested type desc", + v: &Variable{TypeDesc: &TypeDesc{}}, + t: &TypeDesc{}, + }, + { + name: "field type desc", + v: &Variable{Type: &TypeDesc{}}, + t: &TypeDesc{}, + }, + { + name: "nested type desc precedence", + v: &Variable{ + TypeDesc: &TypeDesc{TypeName: "type.name.EmbeddedType"}, + Type: &TypeDesc{TypeName: "type.name.FieldType"}, + }, + t: &TypeDesc{TypeName: "type.name.EmbeddedType"}, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + if !reflect.DeepEqual(tc.v.GetType(), tc.t) { + t.Errorf("GetType() got %v, wanted %v", tc.v.GetType(), tc.t) + } + }) + } +} + +func TestVariableAsCELVariable(t *testing.T) { + tests := []struct { + name string + v *Variable + want any + }{ + { + name: "nil-safety check", + v: nil, + want: errors.New("nil Variable"), + }, + { + name: "no variable name", + v: &Variable{}, + want: errors.New("invalid variable"), + }, + { + name: "no type", + v: &Variable{ + Name: "hello", + }, + want: errors.New("no type specified"), + }, + { + name: "bad type", + v: &Variable{ + Name: "hello", + TypeDesc: &TypeDesc{}, + }, + want: errors.New("declare a type name"), + }, + { + name: "int type", + v: &Variable{ + Name: "int_var", + TypeDesc: &TypeDesc{TypeName: "int"}, + }, + want: decls.NewVariable("int_var", types.IntType), + }, + { + name: "uint type", + v: &Variable{ + Name: "uint_var", + TypeDesc: &TypeDesc{TypeName: "uint"}, + }, + want: 
decls.NewVariable("uint_var", types.UintType), + }, + { + name: "dyn type", + v: &Variable{ + Name: "dyn_var", + TypeDesc: &TypeDesc{TypeName: "dyn"}, + }, + want: decls.NewVariable("dyn_var", types.DynType), + }, + { + name: "list type", + v: &Variable{ + Name: "list_var", + TypeDesc: &TypeDesc{TypeName: "list", Params: []*TypeDesc{{TypeName: "T", IsTypeParam: true}}}, + }, + want: decls.NewVariable("list_var", types.NewListType(types.NewTypeParamType("T"))), + }, + { + name: "map type", + v: &Variable{ + Name: "map_var", + TypeDesc: &TypeDesc{ + TypeName: "map", + Params: []*TypeDesc{ + {TypeName: "string"}, + {TypeName: "optional_type", + Params: []*TypeDesc{{TypeName: "T", IsTypeParam: true}}}, + }, + }, + }, + want: decls.NewVariable("map_var", + types.NewMapType(types.StringType, types.NewOptionalType(types.NewTypeParamType("T")))), + }, + { + name: "set type", + v: &Variable{ + Name: "set_var", + TypeDesc: &TypeDesc{ + TypeName: "set", + Params: []*TypeDesc{ + {TypeName: "string"}, + }, + }, + }, + want: decls.NewVariable("set_var", types.NewOpaqueType("set", types.StringType)), + }, + { + name: "string type - nested type precedence", + v: &Variable{ + Name: "hello", + TypeDesc: &TypeDesc{TypeName: "string"}, + Type: &TypeDesc{TypeName: "int"}, + }, + want: decls.NewVariable("hello", types.StringType), + }, + { + name: "wrapper type variable", + v: &Variable{ + Name: "msg", + TypeDesc: &TypeDesc{TypeName: "google.protobuf.StringValue"}, + }, + want: decls.NewVariable("msg", types.NewNullableType(types.StringType)), + }, + } + + tp, err := types.NewRegistry() + if err != nil { + t.Fatalf("types.NewRegistry() failed: %v", err) + } + tp.RegisterType(types.NewOpaqueType("set", types.NewTypeParamType("T"))) + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + gotVar, err := tc.v.AsCELVariable(tp) + if err != nil { + wantErr, ok := tc.want.(error) + if !ok { + t.Fatalf("AsCELVariable() got error %v, wanted %v", err, tc.want) + } + if 
!strings.Contains(err.Error(), wantErr.Error()) { + t.Fatalf("AsCELVariable() got error %v, wanted error contining %v", err, wantErr) + } + return + } + if !gotVar.DeclarationIsEquivalent(tc.want.(*decls.VariableDecl)) { + t.Errorf("AsCELVariable() got %v, wanted %v", gotVar, tc.want) + } + }) + } +} + +func TestFunctionAsCELFunction(t *testing.T) { + tests := []struct { + name string + f *Function + want any + }{ + { + name: "nil function", + f: nil, + want: errors.New("nil Function"), + }, + { + name: "unnamed function", + f: &Function{}, + want: errors.New("must declare a name"), + }, + { + name: "no overloads", + f: &Function{Name: "no_overloads"}, + want: errors.New("must declare an overload"), + }, + { + name: "nil overload", + f: &Function{Name: "no_overloads", Overloads: []*Overload{nil}}, + want: errors.New("nil Overload"), + }, + { + name: "no return type", + f: &Function{Name: "size", + Overloads: []*Overload{ + {ID: "size_string", + Args: []*TypeDesc{{TypeName: "string"}}, + }, + }, + }, + want: errors.New("missing return type"), + }, + { + name: "bad return type", + f: &Function{Name: "size", + Overloads: []*Overload{ + {ID: "size_string", + Args: []*TypeDesc{{TypeName: "string"}}, + Return: &TypeDesc{}, + }, + }, + }, + want: errors.New("invalid type"), + }, + { + name: "bad arg type", + f: &Function{Name: "size", + Overloads: []*Overload{ + {ID: "size_string", + Args: []*TypeDesc{{}}, + Return: &TypeDesc{}, + }, + }, + }, + want: errors.New("invalid type"), + }, + { + name: "bad target type", + f: &Function{Name: "size", + Overloads: []*Overload{ + {ID: "string_size", + Target: &TypeDesc{}, + Args: []*TypeDesc{}, + Return: &TypeDesc{TypeName: "int"}, + }, + }, + }, + want: errors.New("invalid type"), + }, + { + name: "global function", + f: &Function{Name: "size", + Overloads: []*Overload{ + {ID: "size_string", + Args: []*TypeDesc{{TypeName: "string"}}, + Return: &TypeDesc{TypeName: "int"}}, + }, + }, + want: mustNewFunction(t, "size", 
decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + }, + { + name: "member function", + f: &Function{Name: "size", + Overloads: []*Overload{ + {ID: "string_size", + Target: &TypeDesc{TypeName: "string"}, + Return: &TypeDesc{TypeName: "int"}}, + }, + }, + want: mustNewFunction(t, "size", decls.MemberOverload("string_size", []*types.Type{types.StringType}, types.IntType)), + }, + } + tp, err := types.NewRegistry() + if err != nil { + t.Fatalf("types.NewRegistry() failed: %v", err) + } + tp.RegisterType(types.NewOpaqueType("set", types.NewTypeParamType("T"))) + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + gotFn, err := tc.f.AsCELFunction(tp) + if err != nil { + wantErr, ok := tc.want.(error) + if !ok { + t.Fatalf("AsCELFunction() got error %v, wanted %v", err, tc.want) + } + if !strings.Contains(err.Error(), wantErr.Error()) { + t.Fatalf("AsCELFunction() got error %v, wanted error contining %v", err, wantErr) + } + return + } + assertFuncEquals(t, gotFn, tc.want.(*decls.FunctionDecl)) + }) + } +} + +func TestTypeDescAsCELTypeErrors(t *testing.T) { + tests := []struct { + name string + t *TypeDesc + want any + }{ + { + name: "nil-safety check", + t: nil, + want: errors.New("nil TypeDesc"), + }, + { + name: "no type name", + t: &TypeDesc{}, + want: errors.New("invalid type"), + }, + { + name: "invalid optional", + t: &TypeDesc{TypeName: "optional"}, + want: errors.New("unexpected param count"), + }, + { + name: "invalid optional param type", + t: &TypeDesc{TypeName: "optional", Params: []*TypeDesc{{}}}, + want: errors.New("invalid type"), + }, + { + name: "invalid list", + t: &TypeDesc{TypeName: "list"}, + want: errors.New("unexpected param count"), + }, + { + name: "invalid list param type", + t: &TypeDesc{TypeName: "list", Params: []*TypeDesc{{}}}, + want: errors.New("invalid type"), + }, + { + name: "invalid map", + t: &TypeDesc{TypeName: "map"}, + want: errors.New("unexpected param count"), + }, + { + 
name: "invalid map key type", + t: &TypeDesc{TypeName: "map", Params: []*TypeDesc{{}, {}}}, + want: errors.New("invalid type"), + }, + { + name: "invalid map value type", + t: &TypeDesc{TypeName: "map", Params: []*TypeDesc{{TypeName: "string"}, {}}}, + want: errors.New("invalid type"), + }, + { + name: "invalid set", + t: &TypeDesc{TypeName: "set", Params: []*TypeDesc{{}}}, + want: errors.New("invalid type"), + }, + { + name: "undefined type identifier", + t: &TypeDesc{TypeName: "vector"}, + want: errors.New("undefined type"), + }, + } + tp, err := types.NewRegistry() + if err != nil { + t.Fatalf("types.NewRegistry() failed: %v", err) + } + tp.RegisterType(types.NewOpaqueType("set", types.NewTypeParamType("T"))) + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + gotVar, err := tc.t.AsCELType(tp) + if err != nil { + wantErr, ok := tc.want.(error) + if !ok { + t.Fatalf("AsCELType() got error %v, wanted %v", err, tc.want) + } + if !strings.Contains(err.Error(), wantErr.Error()) { + t.Fatalf("AsCELType() got error %v, wanted error contining %v", err, wantErr) + } + return + } + if !reflect.DeepEqual(gotVar, tc.want.(*decls.VariableDecl)) { + t.Errorf("AsCELType() got %v, wanted %v", gotVar, tc.want) + } + }) + } +} + +func TestSubsetFunction(t *testing.T) { + tests := []struct { + name string + lib *LibrarySubset + orig *decls.FunctionDecl + subset *decls.FunctionDecl + included bool + }{ + { + name: "nil lib, included", + orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + included: true, + }, + { + name: "empty, included", + lib: &LibrarySubset{}, + orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, 
types.IntType)), + included: true, + }, + { + name: "lib, not included allow-list", + lib: &LibrarySubset{ + IncludeFunctions: []*Function{ + {Name: "int"}, + }, + }, + orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + included: false, + }, + { + name: "lib, included whole function", + lib: &LibrarySubset{ + IncludeFunctions: []*Function{ + {Name: "size"}, + }, + }, + orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + included: true, + }, + { + name: "lib, included overload subset", + lib: &LibrarySubset{ + IncludeFunctions: []*Function{ + {Name: "size", Overloads: []*Overload{{ID: "size_string"}}}, + }, + }, + orig: mustNewFunction(t, "size", + decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType), + decls.Overload("size_list", []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType), + ), + subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + included: true, + }, + { + name: "lib, included deny-list", + lib: &LibrarySubset{ + ExcludeFunctions: []*Function{ + {Name: "int"}, + }, + }, + orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + included: true, + }, + { + name: "lib, excluded whole function", + lib: &LibrarySubset{ + ExcludeFunctions: []*Function{ + {Name: "size"}, + }, + }, + orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + included: false, + }, + { + name: "lib, excluded partial function", + lib: &LibrarySubset{ + ExcludeFunctions: []*Function{ + {Name: "size", 
Overloads: []*Overload{{ID: "size_list"}}}, + }, + }, + orig: mustNewFunction(t, "size", + decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType), + decls.Overload("size_list", []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType), + ), + subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + included: true, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + got, included := tc.lib.SubsetFunction(tc.orig) + if included != tc.included { + t.Fatalf("SubsetFunction() got included %t, wanted %t", included, tc.included) + } + if !tc.included { + return + } + assertFuncEquals(t, got, tc.subset) + }) + } +} + +func TestSubsetMacro(t *testing.T) { + tests := []struct { + name string + lib *LibrarySubset + macroName string + included bool + }{ + { + name: "nil lib, included", + macroName: "has", + included: true, + }, + { + name: "empty, included", + lib: &LibrarySubset{}, + macroName: "has", + included: true, + }, + { + name: "empty, included", + lib: &LibrarySubset{DisableMacros: true}, + macroName: "has", + included: false, + }, + { + name: "lib, not included allow-list", + lib: &LibrarySubset{ + IncludeMacros: []string{"exists"}, + }, + macroName: "has", + included: false, + }, + { + name: "lib, included allow-list", + lib: &LibrarySubset{ + IncludeMacros: []string{"exists"}, + }, + macroName: "exists", + included: true, + }, + { + name: "lib, not included deny-list", + lib: &LibrarySubset{ + ExcludeMacros: []string{"exists"}, + }, + macroName: "exists", + included: false, + }, + { + name: "lib, included deny-list", + lib: &LibrarySubset{ + ExcludeMacros: []string{"exists"}, + }, + macroName: "has", + included: true, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + included := tc.lib.SubsetMacro(tc.macroName) + if included != tc.included { + t.Fatalf("SubsetMacro() got included %t, 
wanted %t", included, tc.included)
+			}
+		})
+	}
+}
+
+func TestExtensionGetVersion(t *testing.T) {
+	tests := []struct {
+		name string
+		ext  *Extension
+		want any
+	}{
+		{
+			name: "nil extension",
+			want: errors.New("nil Extension"),
+		},
+		{
+			name: "unset version",
+			ext:  &Extension{},
+			want: uint32(0),
+		},
+		{
+			name: "numeric version",
+			ext:  &Extension{Version: "1"},
+			want: uint32(1),
+		},
+		{
+			name: "latest version",
+			ext:  &Extension{Version: "latest"},
+			want: uint32(math.MaxUint32),
+		},
+		{
+			name: "bad version",
+			ext:  &Extension{Version: "1.0"},
+			want: errors.New("invalid syntax"),
+		},
+	}
+	for _, tst := range tests {
+		tc := tst
+		t.Run(tc.name, func(t *testing.T) {
+			ver, err := tc.ext.GetVersion()
+			if err != nil {
+				wantErr, ok := tc.want.(error)
+				if !ok {
+					t.Fatalf("GetVersion() got error %v, wanted %v", err, tc.want)
+				}
+				if !strings.Contains(err.Error(), wantErr.Error()) {
+					t.Fatalf("GetVersion() got error %v, wanted error containing %v", err, wantErr)
+				}
+				return
+			}
+			if tc.want.(uint32) != ver {
+				t.Fatalf("GetVersion() got %d, wanted %v", ver, tc.want)
+			}
+		})
+	}
+}
+
+func mustNewFunction(t *testing.T, name string, opts ...decls.FunctionOpt) *decls.FunctionDecl {
+	t.Helper()
+	fn, err := decls.NewFunction(name, opts...)
+ if err != nil { + t.Fatalf("decls.NewFunction() failed: %v", err) + } + return fn +} + +func assertFuncEquals(t *testing.T, got, want *decls.FunctionDecl) { + t.Helper() + if got.Name() != want.Name() { + t.Fatalf("got function name %s, wanted %s", got.Name(), want.Name()) + } + if len(got.OverloadDecls()) != len(want.OverloadDecls()) { + t.Fatalf("got overload count %d, wanted %d", len(got.OverloadDecls()), len(want.OverloadDecls())) + } + for i, gotOverload := range got.OverloadDecls() { + wantOverload := want.OverloadDecls()[i] + if gotOverload.ID() != wantOverload.ID() { + t.Errorf("got overload id: %s, wanted: %s", gotOverload.ID(), wantOverload.ID()) + } + if gotOverload.IsMemberFunction() != wantOverload.IsMemberFunction() { + t.Errorf("got is member function %t, wanted %t", gotOverload.IsMemberFunction(), wantOverload.IsMemberFunction()) + } + if len(gotOverload.ArgTypes()) != len(wantOverload.ArgTypes()) { + t.Fatalf("got arg count %d, wanted %d", len(gotOverload.ArgTypes()), len(wantOverload.ArgTypes())) + } + for i, p := range gotOverload.ArgTypes() { + wp := wantOverload.ArgTypes()[i] + if !p.IsExactType(wp) { + t.Errorf("got arg[%d] type %v, wanted %v", i, p, wp) + } + } + if len(gotOverload.TypeParams()) != len(wantOverload.TypeParams()) { + t.Fatalf("got type param count %d, wanted %d", len(gotOverload.TypeParams()), len(wantOverload.TypeParams())) + } + for i, p := range gotOverload.TypeParams() { + wp := wantOverload.TypeParams()[i] + if p != wp { + t.Errorf("got type param[%d] %s, wanted %s", i, p, wp) + } + } + if !gotOverload.ResultType().IsExactType(wantOverload.ResultType()) { + t.Errorf("got result type %v, wanted %v", gotOverload.ResultType(), wantOverload.ResultType()) + } + } +} diff --git a/common/types/types.go b/common/types/types.go index f419beabd..d5ce60f16 100644 --- a/common/types/types.go +++ b/common/types/types.go @@ -164,9 +164,9 @@ var ( traits.SubtractorType, } // ListType represents the runtime list type. 
- ListType = NewListType(nil) + ListType = NewListType(DynType) // MapType represents the runtime map type. - MapType = NewMapType(nil, nil) + MapType = NewMapType(DynType, DynType) // NullType represents the type of a null value. NullType = &Type{ kind: NullTypeKind, diff --git a/common/types/types_test.go b/common/types/types_test.go index a94fc29a2..14e9a8dee 100644 --- a/common/types/types_test.go +++ b/common/types/types_test.go @@ -92,11 +92,11 @@ func TestTypeString(t *testing.T) { }, { in: ListType, - out: "list()", + out: "list(dyn)", }, { in: MapType, - out: "map(, )", + out: "map(dyn, dyn)", }, } for _, tst := range tests { diff --git a/go.sum b/go.sum index fdf55b76d..fbd276c1e 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,3 @@ -cel.dev/expr v0.18.0 h1:CJ6drgk+Hf96lkLikr4rFf19WrU0BOWEihyZnI2TAzo= -cel.dev/expr v0.18.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= -cel.dev/expr v0.19.0 h1:lXuo+nDhpyJSpWxpPVi5cPUwzKb+dsdOiw6IreM5yt0= -cel.dev/expr v0.19.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= cel.dev/expr v0.19.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 6fc40f4fa..cf1f56de1 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -2183,10 +2183,6 @@ func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry) * if err != nil { t.Fatalf("checker.NewEnv(%v, %v) failed: %v", cont, reg, err) } - err = env.AddIdents(stdlib.Types()...) - if err != nil { - t.Fatalf("env.Add(stdlib.Types()...) failed: %v", err) - } err = env.AddFunctions(stdlib.Functions()...) if err != nil { t.Fatalf("env.Add(stdlib.Functions()...) 
failed: %v", err) diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index 5fec59d78..875f523c3 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -37,6 +37,7 @@ go_library( "//common/ast:go_default_library", "//common/containers:go_default_library", "//common/decls:go_default_library", + "//common/env:go_default_library", "//common/operators:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", diff --git a/policy/compiler_test.go b/policy/compiler_test.go index 34c0b6545..9a4497846 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -31,8 +31,10 @@ import ( func TestCompile(t *testing.T) { for _, tst := range policyTests { - r := newRunner(t, tst.name, tst.expr, tst.parseOpts, tst.envOpts...) - r.run(t) + t.Run(tst.name, func(t *testing.T) { + r := newRunner(t, tst.name, tst.expr, tst.parseOpts, tst.envOpts...) + r.run(t) + }) } } @@ -171,7 +173,7 @@ func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel. t.Fatalf("env.Extend() with env options %v, failed: %v", config, err) } // Configure declarations - configOpts, err := config.AsEnvOptions(env) + configOpts, err := config.AsEnvOptions(env.CELTypeProvider()) if err != nil { t.Fatalf("config.AsEnvOptions() failed: %v", err) } diff --git a/policy/config.go b/policy/config.go index cab8da1c0..50f2852a2 100644 --- a/policy/config.go +++ b/policy/config.go @@ -17,252 +17,118 @@ package policy import ( "errors" "fmt" - "math" - "strconv" "google.golang.org/protobuf/proto" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/env" + "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/ext" ) +// NewConfig returns a YAML serializable policy environment. +func NewConfig(e *env.Config) *Config { + return &Config{Config: e} +} + // Config represents a YAML serializable CEL environment configuration. 
type Config struct { - Name string `yaml:"name"` - Description string `yaml:"description"` - Container string `yaml:"container"` - Extensions []*ExtensionConfig `yaml:"extensions"` - Variables []*VariableDecl `yaml:"variables"` - Functions []*FunctionDecl `yaml:"functions"` + *env.Config } // AsEnvOptions converts the Config value to a collection of cel environment options. -func (c *Config) AsEnvOptions(baseEnv *cel.Env) ([]cel.EnvOption, error) { +func (c *Config) AsEnvOptions(provider types.Provider) ([]cel.EnvOption, error) { envOpts := []cel.EnvOption{} + // Configure the standard lib subset. + if c.StdLib != nil { + envOpts = append(envOpts, func(e *cel.Env) (*cel.Env, error) { + return cel.NewCustomEnv(cel.StdLib(cel.StdLibSubset(c.StdLib))) + }) + } + + // Configure the container if c.Container != "" { envOpts = append(envOpts, cel.Container(c.Container)) } - for _, e := range c.Extensions { - opt, err := e.AsEnvOption(baseEnv) - if err != nil { - return nil, err - } - envOpts = append(envOpts, opt) - } - for _, v := range c.Variables { - opt, err := v.AsEnvOption(baseEnv) - if err != nil { - return nil, err - } - envOpts = append(envOpts, opt) - } - for _, f := range c.Functions { - opt, err := f.AsEnvOption(baseEnv) - if err != nil { - return nil, err - } - envOpts = append(envOpts, opt) - } - return envOpts, nil -} - -// ExtensionFactory accepts a version number and produces a CEL environment associated with the versioned -// extension. -type ExtensionFactory func(uint32) cel.EnvOption - -// ExtensionResolver provides a way to lookup ExtensionFactory instances by extension name. -type ExtensionResolver interface { - // ResolveExtension returns an ExtensionFactory bound to the given name, if one exists. - ResolveExtension(name string) (ExtensionFactory, bool) -} - -// ExtensionConfig represents a YAML serializable definition of a versioned extension library reference. 
-type ExtensionConfig struct { - Name string `yaml:"name"` - Version string `yaml:"version"` - ExtensionResolver -} -// AsEnvOption converts an ExtensionConfig value to a CEL environment option. -func (ec *ExtensionConfig) AsEnvOption(baseEnv *cel.Env) (cel.EnvOption, error) { - fac, found := extFactories[ec.Name] - if !found && ec.ExtensionResolver != nil { - fac, found = ec.ResolveExtension(ec.Name) - } - if !found { - return nil, fmt.Errorf("unrecognized extension: %s", ec.Name) + // Configure abbreviations + for _, imp := range c.Imports { + envOpts = append(envOpts, cel.Abbrevs(imp.Name)) } - // If the version is 'latest', set the version value to the max uint. - if ec.Version == "latest" { - return fac(math.MaxUint32), nil - } - if ec.Version == "" { - return fac(0), nil - } - ver, err := strconv.ParseUint(ec.Version, 10, 32) - if err != nil { - return nil, fmt.Errorf("error parsing uint version: %w", err) - } - return fac(uint32(ver)), nil -} - -// VariableDecl represents a YAML serializable CEL variable declaration. -type VariableDecl struct { - Name string `yaml:"name"` - Type *TypeDecl `yaml:"type"` - ContextProto string `yaml:"context_proto"` -} -// AsEnvOption converts a VariableDecl type to a CEL environment option. -// -// Note, variable definitions with differing type definitions will result in an error during -// the compile step. 
-func (vd *VariableDecl) AsEnvOption(baseEnv *cel.Env) (cel.EnvOption, error) { - if vd.Name != "" { - t, err := vd.Type.AsCELType(baseEnv) - if err != nil { - return nil, fmt.Errorf("invalid variable type for '%s': %w", vd.Name, err) + // Configure the context variable declaration + if c.ContextVariable != nil { + if len(c.Variables) > 0 { + return nil, errors.New("either the context_variable or the variables may be set, but not both") } - return cel.Variable(vd.Name, t), nil - } - if vd.ContextProto != "" { - if _, found := baseEnv.CELTypeProvider().FindStructType(vd.ContextProto); !found { - return nil, fmt.Errorf("could not find context proto type name: %s", vd.ContextProto) + typeName := c.ContextVariable.TypeName + if typeName == "" { + return nil, errors.New("invalid context variable, must set type name field") + } + if _, found := provider.FindStructType(typeName); !found { + return nil, fmt.Errorf("could not find context proto type name: %s", typeName) } // Attempt to instantiate the proto in order to reflect to its descriptor - msg := baseEnv.CELTypeProvider().NewValue(vd.ContextProto, map[string]ref.Val{}) + msg := provider.NewValue(typeName, map[string]ref.Val{}) pbMsg, ok := msg.Value().(proto.Message) if !ok { return nil, fmt.Errorf("type name was not a protobuf: %T", msg.Value()) } - return cel.DeclareContextProto(pbMsg.ProtoReflect().Descriptor()), nil + envOpts = append(envOpts, cel.DeclareContextProto(pbMsg.ProtoReflect().Descriptor())) } - return nil, errors.New("invalid variable, must set 'name' or 'context_proto' field") -} -// TypeDecl represents a YAML serializable CEL type reference. -type TypeDecl struct { - TypeName string `yaml:"type_name"` - Params []*TypeDecl `yaml:"params"` - IsTypeParam bool `yaml:"is_type_param"` -} - -// AsCELType converts the TypeDecl value to a cel.Type value using the input base environment. 
-// -// All extension types referenced by name within the `TypeDecl.TypeName` field must be configured -// within the base CEL environment argument. -func (td *TypeDecl) AsCELType(baseEnv *cel.Env) (*cel.Type, error) { - var err error - switch td.TypeName { - case "dyn": - return cel.DynType, nil - case "map": - if len(td.Params) == 2 { - kt, err := td.Params[0].AsCELType(baseEnv) + if len(c.Variables) != 0 { + vars := make([]*decls.VariableDecl, 0, len(c.Variables)) + for _, v := range c.Variables { + vDef, err := v.AsCELVariable(provider) if err != nil { return nil, err } - vt, err := td.Params[1].AsCELType(baseEnv) - if err != nil { - return nil, err - } - return cel.MapType(kt, vt), nil + vars = append(vars, vDef) } - return nil, fmt.Errorf("map type has unexpected param count: %d", len(td.Params)) - case "list": - if len(td.Params) == 1 { - et, err := td.Params[0].AsCELType(baseEnv) - if err != nil { - return nil, err - } - return cel.ListType(et), nil - } - return nil, fmt.Errorf("list type has unexpected param count: %d", len(td.Params)) - default: - if td.IsTypeParam { - return cel.TypeParamType(td.TypeName), nil - } - if msgType, found := baseEnv.CELTypeProvider().FindStructType(td.TypeName); found { - // First parameter is the type name. 
- return msgType.Parameters()[0], nil - } - t, found := baseEnv.CELTypeProvider().FindIdent(td.TypeName) - if !found { - return nil, fmt.Errorf("undefined type name: %v", td.TypeName) - } - _, ok := t.(*cel.Type) - if ok && len(td.Params) == 0 { - return t.(*cel.Type), nil - } - params := make([]*cel.Type, len(td.Params)) - for i, p := range td.Params { - params[i], err = p.AsCELType(baseEnv) + envOpts = append(envOpts, cel.VariableDecls(vars...)) + } + if len(c.Functions) != 0 { + funcs := make([]*decls.FunctionDecl, 0, len(c.Functions)) + for _, f := range c.Functions { + fnDef, err := f.AsCELFunction(provider) if err != nil { return nil, err } + funcs = append(funcs, fnDef) } - return cel.OpaqueType(td.TypeName, params...), nil + envOpts = append(envOpts, cel.FunctionDecls(funcs...)) } -} - -// FunctionDecl represents a YAML serializable declaration of a CEL function. -type FunctionDecl struct { - Name string `yaml:"name"` - Overloads []*OverloadDecl `yaml:"overloads"` -} - -// AsEnvOption converts a FunctionDecl value into a cel.EnvOption using the input environment. -func (fd *FunctionDecl) AsEnvOption(baseEnv *cel.Env) (cel.EnvOption, error) { - overloads := make([]cel.FunctionOpt, len(fd.Overloads)) - var err error - for i, o := range fd.Overloads { - overloads[i], err = o.AsFunctionOption(baseEnv) + for _, e := range c.Extensions { + opt, err := extensionEnvOption(e) if err != nil { return nil, err } + envOpts = append(envOpts, opt) } - return cel.Function(fd.Name, overloads...), nil -} - -// OverloadDecl represents a YAML serializable declaration of a CEL function overload. -type OverloadDecl struct { - OverloadID string `yaml:"id"` - Target *TypeDecl `yaml:"target"` - Args []*TypeDecl `yaml:"args"` - Return *TypeDecl `yaml:"return"` + return envOpts, nil } -// AsFunctionOption converts an OverloadDecl value into a cel.FunctionOpt using the input environment. 
-func (od *OverloadDecl) AsFunctionOption(baseEnv *cel.Env) (cel.FunctionOpt, error) { - args := make([]*cel.Type, len(od.Args)) - var err error - for i, a := range od.Args { - args[i], err = a.AsCELType(baseEnv) - if err != nil { - return nil, err - } - } - - if od.Return == nil { - return nil, fmt.Errorf("missing return type on overload: %v", od.OverloadID) +// extensionEnvOption converts an ExtensionConfig value to a CEL environment option. +func extensionEnvOption(ec *env.Extension) (cel.EnvOption, error) { + fac, found := extFactories[ec.Name] + if !found { + return nil, fmt.Errorf("unrecognized extension: %s", ec.Name) } - result, err := od.Return.AsCELType(baseEnv) + // If the version is 'latest', set the version value to the max uint. + ver, err := ec.GetVersion() if err != nil { return nil, err } - if od.Target != nil { - t, err := od.Target.AsCELType(baseEnv) - if err != nil { - return nil, err - } - args = append([]*cel.Type{t}, args...) - return cel.MemberOverload(od.OverloadID, args, result), nil - } - return cel.Overload(od.OverloadID, args, result), nil + return fac(ver), nil } -var extFactories = map[string]ExtensionFactory{ +// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension. 
+type extensionFactory func(uint32) cel.EnvOption + +var extFactories = map[string]extensionFactory{ "bindings": func(version uint32) cel.EnvOption { return ext.Bindings(ext.BindingsVersion(version)) }, diff --git a/policy/config_test.go b/policy/config_test.go index 68196327a..df28990a7 100644 --- a/policy/config_test.go +++ b/policy/config_test.go @@ -15,11 +15,10 @@ package policy import ( - "strings" "testing" "github.com/google/cel-go/cel" - "github.com/google/cel-go/ext" + "github.com/google/cel-go/common/env" "gopkg.in/yaml.v3" @@ -104,7 +103,7 @@ variables: } for _, tst := range tests { c := parseConfigYaml(t, tst) - _, err := c.AsEnvOptions(baseEnv) + _, err := c.AsEnvOptions(baseEnv.CELTypeProvider()) if err != nil { t.Errorf("AsEnvOptions() generated error: %v", err) } @@ -166,13 +165,13 @@ variables: type_name: "map" params: - type_name: "string" - - type_name: "optional"`, - err: "invalid variable type for 'bad_map_type_param': undefined type name: optional", + - type_name: "invalid_opaque_type"`, + err: "invalid variable type for 'bad_map_type_param': undefined type name: invalid_opaque_type", }, { config: ` -variables: - - context_proto: "bad.proto.MessageType" +context_variable: + type_name: "bad.proto.MessageType" `, err: "could not find context proto type name: bad.proto.MessageType", }, @@ -181,7 +180,7 @@ variables: variables: - type: type_name: "no variable name"`, - err: "invalid variable, must set 'name' or 'context_proto' field", + err: "invalid variable, must declare a name", }, { @@ -235,60 +234,17 @@ functions: } for _, tst := range tests { c := parseConfigYaml(t, tst.config) - _, err := c.AsEnvOptions(baseEnv) + _, err := c.AsEnvOptions(baseEnv.CELTypeProvider()) if err == nil || err.Error() != tst.err { t.Errorf("AsEnvOptions() got error: %v, wanted %s", err, tst.err) } } } -func TestExtensionResolver(t *testing.T) { - ext := ` -extensions: - - name: "math" - - name: "strings_en_US" - version: 1` - - baseEnv, err := cel.NewEnv() - if 
err != nil { - t.Fatalf("cel.NewEnv() failed: %v", err) - } - c := parseConfigYaml(t, ext) - for _, e := range c.Extensions { - e.ExtensionResolver = stringLocaleResolver{} - } - opts, err := c.AsEnvOptions(baseEnv) - if err != nil { - t.Errorf("AsEnvOptions() generated error: %v", err) - } - extEnv, err := baseEnv.Extend(opts...) - if err != nil { - t.Fatalf("baseEnv.Extend() failed: %v", err) - } - if !extEnv.HasLibrary("cel.lib.ext.strings") || !extEnv.HasLibrary("cel.lib.ext.math") { - t.Error("extended env did not contain standardized or custom extensions") - } -} - func parseConfigYaml(t *testing.T, doc string) *Config { - config := &Config{} + config := &env.Config{} if err := yaml.Unmarshal([]byte(doc), config); err != nil { t.Fatalf("yaml.Unmarshal(%q) failed: %v", doc, err) } - return config -} - -type stringLocaleResolver struct{} - -func (stringLocaleResolver) ResolveExtension(name string) (ExtensionFactory, bool) { - parts := strings.SplitN(name, "_", 2) - if len(parts) == 2 && parts[0] == "strings" { - return func(version uint32) cel.EnvOption { - return ext.Strings( - ext.StringsLocale(parts[1]), - ext.StringsVersion(version), - ) - }, true - } - return nil, false + return NewConfig(config) } diff --git a/policy/go.mod b/policy/go.mod index abc045142..6c996e61a 100644 --- a/policy/go.mod +++ b/policy/go.mod @@ -9,7 +9,7 @@ require ( ) require ( - cel.dev/expr v0.18.0 // indirect + cel.dev/expr v0.19.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect diff --git a/policy/go.sum b/policy/go.sum index b9aa3f0a3..a0f0bfd91 100644 --- a/policy/go.sum +++ b/policy/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.18.0 h1:CJ6drgk+Hf96lkLikr4rFf19WrU0BOWEihyZnI2TAzo= -cel.dev/expr v0.18.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= +cel.dev/expr v0.19.1/go.mod 
h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/policy/helper_test.go b/policy/helper_test.go index f65b19c68..ba612fc08 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -283,12 +284,12 @@ func readPolicyConfig(t testing.TB, fileName string) *Config { if err != nil { t.Fatalf("os.ReadFile(%s) failed: %v", fileName, err) } - config := &Config{} + config := &env.Config{} err = yaml.Unmarshal(testCaseBytes, config) if err != nil { log.Fatalf("yaml.Unmarshal(%s) error: %v", fileName, err) } - return config + return NewConfig(config) } func readTestSuite(t testing.TB, fileName string) *TestSuite { diff --git a/policy/parser.go b/policy/parser.go index 7b34b246a..65ada13fc 100644 --- a/policy/parser.go +++ b/policy/parser.go @@ -484,7 +484,7 @@ func (p *parserImpl) parseYAML(src *Source) *Policy { var docNode yaml.Node err := sourceToYAML(src, &docNode) if err != nil { - p.iss.ReportErrorAtID(0, err.Error()) + p.iss.ReportErrorAtID(0, "%s", err.Error()) return nil } // Entry point always has a single Content node diff --git a/policy/testdata/context_pb/config.yaml b/policy/testdata/context_pb/config.yaml index 80dd1dd0a..53ea95425 100644 --- a/policy/testdata/context_pb/config.yaml +++ b/policy/testdata/context_pb/config.yaml @@ -17,5 +17,5 @@ container: "google.expr.proto3" extensions: - name: "strings" version: 2 -variables: - - context_proto: "google.expr.proto3.test.TestAllTypes" +context_variable: + type_name: "google.expr.proto3.test.TestAllTypes" diff --git 
a/policy/testdata/k8s/config.yaml b/policy/testdata/k8s/config.yaml index aa1adb2b0..15a32b535 100644 --- a/policy/testdata/k8s/config.yaml +++ b/policy/testdata/k8s/config.yaml @@ -18,16 +18,13 @@ extensions: version: 2 variables: - name: "resource.labels" - type: - type_name: "map" - params: - - type_name: "string" - - type_name: "string" + type_name: "map" + params: + - type_name: "string" + - type_name: "string" - name: "resource.containers" - type: - type_name: "list" - params: - - type_name: "string" + type_name: "list" + params: + - type_name: "string" - name: "resource.namespace" - type: - type_name: "string" + type_name: "string" diff --git a/policy/testdata/nested_rule/config.yaml b/policy/testdata/nested_rule/config.yaml index bfd94b33c..e45466b96 100644 --- a/policy/testdata/nested_rule/config.yaml +++ b/policy/testdata/nested_rule/config.yaml @@ -15,8 +15,7 @@ name: "nested_rule" variables: - name: "resource" - type: - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" diff --git a/policy/testdata/pb/config.yaml b/policy/testdata/pb/config.yaml index 963996f1a..c036d88af 100644 --- a/policy/testdata/pb/config.yaml +++ b/policy/testdata/pb/config.yaml @@ -19,5 +19,4 @@ extensions: version: 2 variables: - name: "spec" - type: - type_name: "google.expr.proto3.test.TestAllTypes" + type_name: "google.expr.proto3.test.TestAllTypes" diff --git a/policy/testdata/required_labels/config.yaml b/policy/testdata/required_labels/config.yaml index 1fae24d46..f9081478a 100644 --- a/policy/testdata/required_labels/config.yaml +++ b/policy/testdata/required_labels/config.yaml @@ -20,14 +20,12 @@ extensions: - name: "two-var-comprehensions" variables: - name: "spec" - type: - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" - name: "resource" - type: - type_name: "map" - params: - - 
type_name: "string" - - type_name: "dyn" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" diff --git a/policy/testdata/restricted_destinations/base_config.yaml b/policy/testdata/restricted_destinations/base_config.yaml index f4c24cf03..2aae385ca 100644 --- a/policy/testdata/restricted_destinations/base_config.yaml +++ b/policy/testdata/restricted_destinations/base_config.yaml @@ -18,28 +18,22 @@ extensions: - name: "sets" variables: - name: "destination.ip" - type: - type_name: "string" + type_name: "string" - name: "origin.ip" - type: - type_name: "string" + type_name: "string" - name: "spec.restricted_destinations" - type: - type_name: "list" - params: - - type_name: "string" + type_name: "list" + params: + - type_name: "string" - name: "spec.origin" - type: - type_name: "string" + type_name: "string" - name: "request" - type: - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" - name: "resource" - type: - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" diff --git a/policy/testdata/restricted_destinations/config.yaml b/policy/testdata/restricted_destinations/config.yaml index f6f24bf1e..3e8b78908 100644 --- a/policy/testdata/restricted_destinations/config.yaml +++ b/policy/testdata/restricted_destinations/config.yaml @@ -18,31 +18,25 @@ extensions: - name: "sets" variables: - name: "destination.ip" - type: - type_name: "string" + type_name: "string" - name: "origin.ip" - type: - type_name: "string" + type_name: "string" - name: "spec.restricted_destinations" - type: - type_name: "list" - params: - - type_name: "string" + type_name: "list" + params: + - type_name: "string" - name: "spec.origin" - type: - type_name: "string" + type_name: "string" - name: "request" - type: - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" + 
type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" - name: "resource" - type: - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" functions: - name: "locationCode" overloads: From af0bf8e86cf5b2f6d48485dab693a4280f7ee639 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 12 Feb 2025 15:05:15 -0800 Subject: [PATCH 10/46] Support for cel.Env conversion to YAML-serializable config (#1128) * Support for cel.Env conversion to an env.Config object * Backed out change remove stdlib types as variables --- cel/cel_test.go | 3 +- cel/env.go | 133 +- cel/env_test.go | 115 + cel/library.go | 59 +- cel/options.go | 25 +- checker/decls/decls.go | 4 +- common/containers/container.go | 14 +- common/containers/container_test.go | 8 +- common/decls/decls.go | 5 + common/env/BUILD.bazel | 33 +- common/env/env.go | 274 +- common/env/env_test.go | 485 +++- common/env/testdata/context_env.yaml | 57 + common/env/testdata/extended_env.yaml | 40 + common/env/testdata/subset_env.yaml | 39 + go.mod | 1 + go.sum | 3 + policy/config.go | 15 +- vendor/gopkg.in/yaml.v3/LICENSE | 50 + vendor/gopkg.in/yaml.v3/NOTICE | 13 + vendor/gopkg.in/yaml.v3/README.md | 150 ++ vendor/gopkg.in/yaml.v3/apic.go | 747 ++++++ vendor/gopkg.in/yaml.v3/decode.go | 1000 ++++++++ vendor/gopkg.in/yaml.v3/emitterc.go | 2020 +++++++++++++++ vendor/gopkg.in/yaml.v3/encode.go | 577 +++++ vendor/gopkg.in/yaml.v3/parserc.go | 1258 ++++++++++ vendor/gopkg.in/yaml.v3/readerc.go | 434 ++++ vendor/gopkg.in/yaml.v3/resolve.go | 326 +++ vendor/gopkg.in/yaml.v3/scannerc.go | 3038 +++++++++++++++++++++++ vendor/gopkg.in/yaml.v3/sorter.go | 134 + vendor/gopkg.in/yaml.v3/writerc.go | 48 + vendor/gopkg.in/yaml.v3/yaml.go | 698 ++++++ vendor/gopkg.in/yaml.v3/yamlh.go | 807 ++++++ vendor/gopkg.in/yaml.v3/yamlprivateh.go | 198 ++ vendor/modules.txt | 3 + 35 files changed, 12653 insertions(+), 161 deletions(-) create 
mode 100644 common/env/testdata/context_env.yaml create mode 100644 common/env/testdata/extended_env.yaml create mode 100644 common/env/testdata/subset_env.yaml create mode 100644 vendor/gopkg.in/yaml.v3/LICENSE create mode 100644 vendor/gopkg.in/yaml.v3/NOTICE create mode 100644 vendor/gopkg.in/yaml.v3/README.md create mode 100644 vendor/gopkg.in/yaml.v3/apic.go create mode 100644 vendor/gopkg.in/yaml.v3/decode.go create mode 100644 vendor/gopkg.in/yaml.v3/emitterc.go create mode 100644 vendor/gopkg.in/yaml.v3/encode.go create mode 100644 vendor/gopkg.in/yaml.v3/parserc.go create mode 100644 vendor/gopkg.in/yaml.v3/readerc.go create mode 100644 vendor/gopkg.in/yaml.v3/resolve.go create mode 100644 vendor/gopkg.in/yaml.v3/scannerc.go create mode 100644 vendor/gopkg.in/yaml.v3/sorter.go create mode 100644 vendor/gopkg.in/yaml.v3/writerc.go create mode 100644 vendor/gopkg.in/yaml.v3/yaml.go create mode 100644 vendor/gopkg.in/yaml.v3/yamlh.go create mode 100644 vendor/gopkg.in/yaml.v3/yamlprivateh.go diff --git a/cel/cel_test.go b/cel/cel_test.go index e85c5786f..77f35d549 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -2234,8 +2234,7 @@ func TestDefaultUTCTimeZoneError(t *testing.T) { || x.getMilliseconds('Am/Ph') == 1 `, map[string]any{ "x": time.Unix(7506, 1000000).Local(), - }, - ) + }) if err == nil { t.Fatalf("prg.Eval() got %v wanted error", out) } diff --git a/cel/env.go b/cel/env.go index e3de439de..16531bb02 100644 --- a/cel/env.go +++ b/cel/env.go @@ -16,6 +16,8 @@ package cel import ( "errors" + "fmt" + "math" "sync" "github.com/google/cel-go/checker" @@ -24,12 +26,15 @@ import ( celast "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/env" + "github.com/google/cel-go/common/stdlib" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/interpreter" "github.com/google/cel-go/parser" exprpb 
"google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "google.golang.org/protobuf/reflect/protoreflect" ) // Source interface representing a user-provided expression. @@ -127,12 +132,13 @@ type Env struct { Container *containers.Container variables []*decls.VariableDecl functions map[string]*decls.FunctionDecl - macros []parser.Macro + macros []Macro + contextProto protoreflect.MessageDescriptor adapter types.Adapter provider types.Provider features map[int]bool appliedFeatures map[int]bool - libraries map[string]bool + libraries map[string]SingletonLibrary validators []ASTValidator costOptions []checker.CostOption @@ -151,6 +157,115 @@ type Env struct { progOpts []ProgramOption } +// ToConfig produces a YAML-serializable env.Config object from the given environment. +// +// The serialized configuration value is intended to represent a baseline set of config +// options which could be used as input to an EnvOption to configure the majority of the +// environment from a file. +// +// Note: validators, features, flags, and safe-guard settings are not yet supported by +// the serialize method. Since optimizers are a separate construct from the environment +// and the standard expression components (parse, check, evalute), they are also not +// supported by the serialize method. +func (e *Env) ToConfig(name string) (*env.Config, error) { + conf := env.NewConfig(name) + // Container settings + if e.Container != containers.DefaultContainer { + conf.SetContainer(e.Container.Name()) + } + for _, typeName := range e.Container.AliasSet() { + conf.AddImports(env.NewImport(typeName)) + } + + libOverloads := map[string][]string{} + for libName, lib := range e.libraries { + // Track the options which have been configured by a library and + // then diff the library version against the configured function + // to detect incremental overloads or rewrites. 
+ libEnv, _ := NewCustomEnv() + libEnv, _ = Lib(lib)(libEnv) + for fnName, fnDecl := range libEnv.Functions() { + if len(fnDecl.OverloadDecls()) == 0 { + continue + } + overloads, exist := libOverloads[fnName] + if !exist { + overloads = make([]string, 0, len(fnDecl.OverloadDecls())) + } + for _, o := range fnDecl.OverloadDecls() { + overloads = append(overloads, o.ID()) + } + libOverloads[fnName] = overloads + } + subsetLib, canSubset := lib.(LibrarySubsetter) + alias := "" + if aliasLib, canAlias := lib.(LibraryAliaser); canAlias { + alias = aliasLib.LibraryAlias() + libName = alias + } + if libName == "stdlib" && canSubset { + conf.SetStdLib(subsetLib.LibrarySubset()) + continue + } + version := uint32(math.MaxUint32) + if versionLib, isVersioned := lib.(LibraryVersioner); isVersioned { + version = versionLib.LibraryVersion() + } + conf.AddExtensions(env.NewExtension(libName, version)) + } + + // If this is a custom environment without the standard env, mark the stdlib as disabled. + if conf.StdLib == nil && !e.HasLibrary("cel.lib.std") { + conf.SetStdLib(env.NewLibrarySubset().SetDisabled(true)) + } + + // Serialize the variables + vars := make([]*decls.VariableDecl, 0, len(e.Variables())) + stdTypeVars := map[string]*decls.VariableDecl{} + for _, v := range stdlib.Types() { + stdTypeVars[v.Name()] = v + } + for _, v := range e.Variables() { + if _, isStdType := stdTypeVars[v.Name()]; isStdType { + continue + } + vars = append(vars, v) + } + if e.contextProto != nil { + conf.SetContextVariable(env.NewContextVariable(string(e.contextProto.FullName()))) + skipVariables := map[string]bool{} + fields := e.contextProto.Fields() + for i := 0; i < fields.Len(); i++ { + field := fields.Get(i) + variable, err := fieldToVariable(field) + if err != nil { + return nil, fmt.Errorf("could not serialize context field variable %q, reason: %w", field.FullName(), err) + } + skipVariables[variable.Name()] = true + } + for _, v := range vars { + if _, found := 
skipVariables[v.Name()]; !found { + conf.AddVariableDecls(v) + } + } + } else { + conf.AddVariableDecls(vars...) + } + + // Serialize functions which are distinct from the ones configured by libraries. + for fnName, fnDecl := range e.Functions() { + if excludedOverloads, found := libOverloads[fnName]; found { + if newDecl := fnDecl.Subset(decls.ExcludeOverloads(excludedOverloads...)); newDecl != nil { + conf.AddFunctionDecls(newDecl) + } + } else { + conf.AddFunctionDecls(fnDecl) + } + } + + return conf, nil +} + // NewEnv creates a program environment configured with the standard library of CEL functions and // macros. The Env value returned can parse and check any CEL program which builds upon the core // features documented in the CEL specification. @@ -194,7 +309,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) { provider: registry, features: map[int]bool{}, appliedFeatures: map[int]bool{}, - libraries: map[string]bool{}, + libraries: map[string]SingletonLibrary{}, validators: []ASTValidator{}, progOpts: []ProgramOption{}, costOptions: []checker.CostOption{}, @@ -362,7 +477,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { for k, v := range e.functions { funcsCopy[k] = v } - libsCopy := make(map[string]bool, len(e.libraries)) + libsCopy := make(map[string]SingletonLibrary, len(e.libraries)) for k, v := range e.libraries { libsCopy[k] = v } @@ -376,6 +491,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { variables: varsCopy, functions: funcsCopy, macros: macsCopy, + contextProto: e.contextProto, progOpts: progOptsCopy, adapter: adapter, features: featuresCopy, @@ -399,8 +515,8 @@ func (e *Env) HasFeature(flag int) bool { // HasLibrary returns whether a specific SingletonLibrary has been configured in the environment. 
func (e *Env) HasLibrary(libName string) bool { - configured, exists := e.libraries[libName] - return exists && configured + _, exists := e.libraries[libName] + return exists } // Libraries returns a list of SingletonLibrary that have been configured in the environment. @@ -423,6 +539,11 @@ func (e *Env) Functions() map[string]*decls.FunctionDecl { return e.functions } +// Variables returns the set of variables associated with the environment. +func (e *Env) Variables() []*decls.VariableDecl { + return e.variables +} + // HasValidator returns whether a specific ASTValidator has been configured in the environment. func (e *Env) HasValidator(name string) bool { for _, v := range e.validators { diff --git a/cel/env_test.go b/cel/env_test.go index 3fbdeac00..64e1f2873 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -16,11 +16,13 @@ package cel import ( "fmt" + "math" "reflect" "sync" "testing" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -302,6 +304,119 @@ func TestFunctions(t *testing.T) { } } +func TestEnvToConfig(t *testing.T) { + tests := []struct { + name string + opts []EnvOption + wantConfig *env.Config + }{ + { + name: "std env", + wantConfig: env.NewConfig("std env"), + }, + { + name: "std env - container", + opts: []EnvOption{ + Container("example.container"), + }, + wantConfig: env.NewConfig("std env - container").SetContainer("example.container"), + }, + { + name: "std env - aliases", + opts: []EnvOption{ + Abbrevs("example.type.name"), + }, + wantConfig: env.NewConfig("std env - aliases").AddImports(env.NewImport("example.type.name")), + }, + { + name: "std env disabled", + opts: []EnvOption{ + func(*Env) (*Env, error) { + return NewCustomEnv() + }, + }, + wantConfig: env.NewConfig("std env disabled").SetStdLib( + env.NewLibrarySubset().SetDisabled(true)), + }, + { + name: "std env - with 
variable", + opts: []EnvOption{ + Variable("var", IntType), + }, + wantConfig: env.NewConfig("std env - with variable").AddVariables(env.NewVariable("var", env.NewTypeDesc("int"))), + }, + { + name: "std env - with function", + opts: []EnvOption{Function("hello", Overload("hello_string", []*Type{StringType}, StringType))}, + wantConfig: env.NewConfig("std env - with function").AddFunctions( + env.NewFunction("hello", []*env.Overload{ + env.NewOverload("hello_string", + []*env.TypeDesc{env.NewTypeDesc("string")}, env.NewTypeDesc("string"))}, + )), + }, + { + name: "optional lib", + opts: []EnvOption{ + OptionalTypes(), + }, + wantConfig: env.NewConfig("optional lib").AddExtensions(env.NewExtension("optional", math.MaxUint32)), + }, + { + name: "optional lib - versioned", + opts: []EnvOption{ + OptionalTypes(OptionalTypesVersion(1)), + }, + wantConfig: env.NewConfig("optional lib - versioned").AddExtensions(env.NewExtension("optional", 1)), + }, + { + name: "optional lib - alt last()", + opts: []EnvOption{ + OptionalTypes(), + Function("last", MemberOverload("string_last", []*Type{StringType}, StringType)), + }, + wantConfig: env.NewConfig("optional lib - alt last()"). + AddExtensions(env.NewExtension("optional", math.MaxUint32)). + AddFunctions(env.NewFunction("last", []*env.Overload{ + env.NewMemberOverload("string_last", env.NewTypeDesc("string"), []*env.TypeDesc{}, env.NewTypeDesc("string")), + })), + }, + { + name: "context proto - with extra variable", + opts: []EnvOption{ + DeclareContextProto((&proto3pb.TestAllTypes{}).ProtoReflect().Descriptor()), + Variable("extra", StringType), + }, + wantConfig: env.NewConfig("context proto - with extra variable"). + SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")). 
+ AddVariables(env.NewVariable("extra", env.NewTypeDesc("string"))), + }, + { + name: "context proto", + opts: []EnvOption{ + DeclareContextProto((&proto3pb.TestAllTypes{}).ProtoReflect().Descriptor()), + }, + wantConfig: env.NewConfig("context proto").SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")), + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + e, err := NewEnv(tc.opts...) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + gotConfig, err := e.ToConfig(tc.name) + if err != nil { + t.Fatalf("ToConfig() failed: %v", err) + } + if !reflect.DeepEqual(gotConfig, tc.wantConfig) { + t.Errorf("e.Config() got %v, wanted %v", gotConfig, tc.wantConfig) + } + }) + } +} + func BenchmarkNewCustomEnvLazy(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/cel/library.go b/cel/library.go index 1d081852c..ebe05dc93 100644 --- a/cel/library.go +++ b/cel/library.go @@ -73,6 +73,23 @@ type SingletonLibrary interface { LibraryName() string } +// LibraryAliaser generates a simple named alias for the library, for use during environment serialization. +type LibraryAliaser interface { + LibraryAlias() string +} + +// LibrarySubsetter provides the subset description associated with the library, nil if not subset. +type LibrarySubsetter interface { + LibrarySubset() *env.LibrarySubset +} + +// LibraryVersioner provides a version number for the library. +// +// If not implemented, the library version will be flagged as 'latest' during environment serialization. +type LibraryVersioner interface { + LibraryVersion() uint32 +} + // Lib creates an EnvOption out of a Library, allowing libraries to be provided as functional args, // and to be linked to each other. 
func Lib(l Library) EnvOption { @@ -82,7 +99,7 @@ func Lib(l Library) EnvOption { if e.HasLibrary(singleton.LibraryName()) { return e, nil } - e.libraries[singleton.LibraryName()] = true + e.libraries[singleton.LibraryName()] = singleton } var err error for _, opt := range l.CompileOptions() { @@ -100,6 +117,10 @@ func Lib(l Library) EnvOption { type StdLibOption func(*stdLibrary) *stdLibrary // StdLibSubset configures the standard library to use a subset of its functions and macros. +// +// Since the StdLib is a singleton library, only the first instance of the StdLib() environment options +// will be configured on the environment which means only the StdLibSubset() initially configured with +// the library will be used. func StdLibSubset(subset *env.LibrarySubset) StdLibOption { return func(lib *stdLibrary) *stdLibrary { lib.subset = subset @@ -127,6 +148,21 @@ func (*stdLibrary) LibraryName() string { return "cel.lib.std" } +// LibraryAlias returns the simple name of the library. +func (*stdLibrary) LibraryAlias() string { + return "stdlib" +} + +// LibraryVersion returns the version of the library. +func (*stdLibrary) LibraryVersion() uint32 { + return math.MaxUint32 +} + +// LibrarySubset returns the env.LibrarySubset definition associated with the CEL Library. +func (lib *stdLibrary) LibrarySubset() *env.LibrarySubset { + return lib.subset +} + // CompileOptions returns options for the standard CEL function declarations and macros. func (lib *stdLibrary) CompileOptions() []EnvOption { funcs := stdlib.Functions() @@ -162,6 +198,10 @@ func (lib *stdLibrary) CompileOptions() []EnvOption { } return e, nil }, + func(e *Env) (*Env, error) { + e.variables = append(e.variables, stdlib.Types()...) + return e, nil + }, Macros(macros...), } } @@ -358,10 +398,20 @@ func OptionalTypesVersion(version uint32) OptionalTypesOption { } // LibraryName implements the SingletonLibrary interface method. 
-func (lib *optionalLib) LibraryName() string { +func (*optionalLib) LibraryName() string { return "cel.lib.optional" } +// LibraryAlias returns the simple name of the library. +func (*optionalLib) LibraryAlias() string { + return "optional" +} + +// LibraryVersion returns the version of the library. +func (lib *optionalLib) LibraryVersion() uint32 { + return lib.version +} + // CompileOptions implements the Library interface method. func (lib *optionalLib) CompileOptions() []EnvOption { paramTypeK := TypeParamType("K") @@ -492,6 +542,11 @@ func (lib *optionalLib) ProgramOptions() []ProgramOption { } } +// Version returns the current version of the library. +func (lib *optionalLib) Version() uint32 { + return lib.version +} + func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) { varIdent := args[0] varName := "" diff --git a/cel/options.go b/cel/options.go index 82b6c8d9f..8b170d5de 100644 --- a/cel/options.go +++ b/cel/options.go @@ -25,6 +25,7 @@ import ( "github.com/google/cel-go/checker" "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/pb" @@ -534,7 +535,7 @@ func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) { return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String()) } -func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) { +func fieldToVariable(field protoreflect.FieldDescriptor) (*decls.VariableDecl, error) { name := string(field.Name()) if field.IsMap() { mapKey := field.MapKey() @@ -547,20 +548,20 @@ func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) { if err != nil { return nil, err } - return Variable(name, MapType(keyType, valueType)), nil + return decls.NewVariable(name, MapType(keyType, valueType)), nil } if field.IsList() { elemType, err := 
fieldToCELType(field) if err != nil { return nil, err } - return Variable(name, ListType(elemType)), nil + return decls.NewVariable(name, ListType(elemType)), nil } celType, err := fieldToCELType(field) if err != nil { return nil, err } - return Variable(name, celType), nil + return decls.NewVariable(name, celType), nil } // DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto. @@ -568,17 +569,25 @@ func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) { // https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption { return func(e *Env) (*Env, error) { + if e.contextProto != nil { + return nil, fmt.Errorf("context proto already declared as %q, got %q", + e.contextProto.FullName(), descriptor.FullName()) + } + e.contextProto = descriptor fields := descriptor.Fields() + vars := make([]*decls.VariableDecl, 0, fields.Len()) for i := 0; i < fields.Len(); i++ { field := fields.Get(i) variable, err := fieldToVariable(field) if err != nil { return nil, err } - e, err = variable(e) - if err != nil { - return nil, err - } + vars = append(vars, variable) + } + var err error + e, err = VariableDecls(vars...)(e) + if err != nil { + return nil, err } return Types(dynamicpb.NewMessage(descriptor))(e) } diff --git a/checker/decls/decls.go b/checker/decls/decls.go index c0e5de469..ef1c4bbb4 100644 --- a/checker/decls/decls.go +++ b/checker/decls/decls.go @@ -231,7 +231,5 @@ func NewWrapperType(wrapped *exprpb.Type) *exprpb.Type { // TODO: return an error panic("Wrapped type must be a primitive") } - return &exprpb.Type{ - TypeKind: &exprpb.Type_Wrapper{ - Wrapper: primitive}} + return &exprpb.Type{TypeKind: &exprpb.Type_Wrapper{Wrapper: primitive}} } diff --git a/common/containers/container.go b/common/containers/container.go index 3097a3f78..fc146b6fc 100644 --- a/common/containers/container.go +++ 
b/common/containers/container.go @@ -63,9 +63,9 @@ func (c *Container) Extend(opts ...ContainerOption) (*Container, error) { } // Copy the name and aliases of the existing container. ext := &Container{name: c.Name()} - if len(c.aliasSet()) > 0 { - aliasSet := make(map[string]string, len(c.aliasSet())) - for k, v := range c.aliasSet() { + if len(c.AliasSet()) > 0 { + aliasSet := make(map[string]string, len(c.AliasSet())) + for k, v := range c.AliasSet() { aliasSet[k] = v } ext.aliases = aliasSet @@ -133,8 +133,8 @@ func (c *Container) ResolveCandidateNames(name string) []string { return append(candidates, name) } -// aliasSet returns the alias to fully-qualified name mapping stored in the container. -func (c *Container) aliasSet() map[string]string { +// AliasSet returns the alias to fully-qualified name mapping stored in the container. +func (c *Container) AliasSet() map[string]string { if c == nil || c.aliases == nil { return noAliases } @@ -160,7 +160,7 @@ func (c *Container) findAlias(name string) (string, bool) { simple = name[0:dot] qualifier = name[dot:] } - alias, found := c.aliasSet()[simple] + alias, found := c.AliasSet()[simple] if !found { return "", false } @@ -264,7 +264,7 @@ func aliasAs(kind, qualifiedName, alias string) ContainerOption { return nil, fmt.Errorf("%s must refer to a valid qualified name: %s", kind, qualifiedName) } - aliasRef, found := c.aliasSet()[alias] + aliasRef, found := c.AliasSet()[alias] if found { return nil, fmt.Errorf( "%s collides with existing reference: name=%s, %s=%s, existing=%s", diff --git a/common/containers/container_test.go b/common/containers/container_test.go index 06efd4198..e8cfb6844 100644 --- a/common/containers/container_test.go +++ b/common/containers/container_test.go @@ -186,8 +186,8 @@ func TestContainers_Extend_Alias(t *testing.T) { if err != nil { t.Fatal(err) } - if c.aliasSet()["alias"] != "test.alias" { - t.Errorf("got alias %v wanted 'test.alias'", c.aliasSet()) + if c.AliasSet()["alias"] != 
"test.alias" { + t.Errorf("got alias %v wanted 'test.alias'", c.AliasSet()) } c, err = c.Extend(Name("with.container")) if err != nil { @@ -196,8 +196,8 @@ func TestContainers_Extend_Alias(t *testing.T) { if c.Name() != "with.container" { t.Errorf("got container name %s, wanted 'with.container'", c.Name()) } - if c.aliasSet()["alias"] != "test.alias" { - t.Errorf("got alias %v wanted 'test.alias'", c.aliasSet()) + if c.AliasSet()["alias"] != "test.alias" { + t.Errorf("got alias %v wanted 'test.alias'", c.AliasSet()) } } diff --git a/common/decls/decls.go b/common/decls/decls.go index df05d2198..cec22707a 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -182,6 +182,8 @@ func ExcludeOverloads(overloadIDs ...string) OverloadSelector { } // Subset returns a new function declaration which contains only the overloads with the specified IDs. +// If the subset function contains no overloads, then nil is returned to indicate the function is not +// functional. func (f *FunctionDecl) Subset(selector OverloadSelector) *FunctionDecl { if f == nil { return nil @@ -195,6 +197,9 @@ func (f *FunctionDecl) Subset(selector OverloadSelector) *FunctionDecl { overloadOrdinals = append(overloadOrdinals, oID) } } + if len(overloads) == 0 { + return nil + } subset := &FunctionDecl{ name: f.Name(), overloads: overloads, diff --git a/common/env/BUILD.bazel b/common/env/BUILD.bazel index 148d49c14..0e7dae1d1 100644 --- a/common/env/BUILD.bazel +++ b/common/env/BUILD.bazel @@ -1,4 +1,18 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") package( default_visibility = ["//visibility:public"], @@ -16,3 +30,20 @@ go_library( "//common/types:go_default_library", ], ) + +go_test( + name = "go_default_test", + size = "small", + srcs = [ + "env_test.go", + ], + data = glob(["testdata/**"]), + embed = [":go_default_library"], + deps = [ + "//common/decls:go_default_library", + "//common/operators:go_default_library", + "//common/overloads:go_default_library", + "//common/types:go_default_library", + "@in_gopkg_yaml_v3//:go_default_library", + ], +) diff --git a/common/env/env.go b/common/env/env.go index 3595b388c..27e28cfd7 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -20,29 +20,26 @@ import ( "fmt" "math" "strconv" + "strings" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/types" ) // NewConfig creates an instance of a YAML serializable CEL environment configuration. -func NewConfig() *Config { +func NewConfig(name string) *Config { return &Config{ - Imports: []*Import{}, - Extensions: []*Extension{}, - Variables: []*Variable{}, - Functions: []*Function{}, + Name: name, } } // Config represents a serializable form of the CEL environment configuration. // -// Note: custom validations, feature flags, and performance tuning parameters are -// not (yet) considered part of the core CEL environment configuration and should -// be managed separately until a common convention for such configuration settings -// can be developed. 
+// Note: custom validations, feature flags, and performance tuning parameters are not (yet) +// considered part of the core CEL environment configuration and should be managed separately +// until a common convention for such configuration settings is developed. type Config struct { - Name string `yaml:"name"` + Name string `yaml:"name,omitempty"` Description string `yaml:"description,omitempty"` Container string `yaml:"container,omitempty"` Imports []*Import `yaml:"imports,omitempty"` @@ -53,17 +50,111 @@ type Config struct { Functions []*Function `yaml:"functions,omitempty"` } +// SetContainer configures the container name for this configuration. +func (c *Config) SetContainer(container string) *Config { + c.Container = container + return c +} + +// AddVariableDecls adds one or more variables to the config, converting them to serializable values first. +// +// VariableDecl inputs are expected to be well-formed. +func (c *Config) AddVariableDecls(vars ...*decls.VariableDecl) *Config { + convVars := make([]*Variable, len(vars)) + for i, v := range vars { + if v == nil { + continue + } + convVars[i] = NewVariable(v.Name(), serializeTypeDesc(v.Type())) + } + return c.AddVariables(convVars...) +} + +// AddVariables adds one or more vairables to the config. +func (c *Config) AddVariables(vars ...*Variable) *Config { + c.Variables = append(c.Variables, vars...) + return c +} + +// SetContextVariable configures the ContextVariable for this configuration. +func (c *Config) SetContextVariable(ctx *ContextVariable) *Config { + c.ContextVariable = ctx + return c +} + +// AddFunctionDecls adds one or more functions to the config, converting them to serializable values first. +// +// FunctionDecl inputs are expected to be well-formed. 
+func (c *Config) AddFunctionDecls(funcs ...*decls.FunctionDecl) *Config { + convFuncs := make([]*Function, len(funcs)) + for i, fn := range funcs { + if fn == nil { + continue + } + overloads := make([]*Overload, 0, len(fn.OverloadDecls())) + for _, o := range fn.OverloadDecls() { + overloadID := o.ID() + args := make([]*TypeDesc, 0, len(o.ArgTypes())) + for _, a := range o.ArgTypes() { + args = append(args, serializeTypeDesc(a)) + } + ret := serializeTypeDesc(o.ResultType()) + if o.IsMemberFunction() { + overloads = append(overloads, NewMemberOverload(overloadID, args[0], args[1:], ret)) + } else { + overloads = append(overloads, NewOverload(overloadID, args, ret)) + } + } + convFuncs[i] = NewFunction(fn.Name(), overloads) + } + return c.AddFunctions(convFuncs...) +} + +// AddFunctions adds one or more functions to the config. +func (c *Config) AddFunctions(funcs ...*Function) *Config { + c.Functions = append(c.Functions, funcs...) + return c +} + +// SetStdLib configures the LibrarySubset for the standard library. +func (c *Config) SetStdLib(subset *LibrarySubset) *Config { + c.StdLib = subset + return c +} + +// AddImports appends a set of imports to the config. +func (c *Config) AddImports(imps ...*Import) *Config { + c.Imports = append(c.Imports, imps...) + return c +} + +// AddExtensions appends a set of extensions to the config. +func (c *Config) AddExtensions(exts ...*Extension) *Config { + c.Extensions = append(c.Extensions, exts...) + return c +} + +// NewImport returns a serializable import value from the qualified type name. +func NewImport(name string) *Import { + return &Import{Name: name} +} + // Import represents a type name that will be appreviated by its simple name using // the cel.Abbrevs() option. 
type Import struct { Name string `yaml:"name"` } +// NewVariable returns a serializable variable from a name and type definition +func NewVariable(name string, t *TypeDesc) *Variable { + return &Variable{Name: name, TypeDesc: t} +} + // Variable represents a typed variable declaration which will be published via the // cel.VariableDecls() option. type Variable struct { Name string `yaml:"name"` - Description string `yaml:"description"` + Description string `yaml:"description,omitempty"` // Type represents the type declaration for the variable. // @@ -109,6 +200,11 @@ func (vd *Variable) AsCELVariable(tp types.Provider) (*decls.VariableDecl, error return nil, fmt.Errorf("invalid variable '%s', no type specified", vd.Name) } +// NewContextVariable returns a serializable context variable with a specific type name. +func NewContextVariable(typeName string) *ContextVariable { + return &ContextVariable{TypeName: typeName} +} + // ContextVariable represents a structured message whose fields are to be treated as the top-level // variable identifiers within CEL expressions. type ContextVariable struct { @@ -117,11 +213,16 @@ type ContextVariable struct { TypeName string `yaml:"type_name"` } +// NewFunction creates a serializable function and overload set. +func NewFunction(name string, overloads []*Overload) *Function { + return &Function{Name: name, Overloads: overloads} +} + // Function represents the serializable format of a function and its overloads. type Function struct { Name string `yaml:"name"` - Description string `yaml:"description"` - Overloads []*Overload `yaml:"overloads"` + Description string `yaml:"description,omitempty"` + Overloads []*Overload `yaml:"overloads,omitempty"` } // AsCELFunction converts the serializable form of the Function into CEL environment declaration. @@ -146,13 +247,23 @@ func (fn *Function) AsCELFunction(tp types.Provider) (*decls.FunctionDecl, error return decls.NewFunction(fn.Name, overloads...) 
} +// NewOverload returns a new serializable representation of a global overload. +func NewOverload(id string, args []*TypeDesc, ret *TypeDesc) *Overload { + return &Overload{ID: id, Args: args, Return: ret} +} + +// NewMemberOverload returns a new serializable representation of a member (receiver) overload. +func NewMemberOverload(id string, target *TypeDesc, args []*TypeDesc, ret *TypeDesc) *Overload { + return &Overload{ID: id, Target: target, Args: args, Return: ret} +} + // Overload represents the serializable format of a function overload. type Overload struct { ID string `yaml:"id"` - Description string `yaml:"description"` - Target *TypeDesc `yaml:"target"` - Args []*TypeDesc `yaml:"args"` - Return *TypeDesc `yaml:"return"` + Description string `yaml:"description,omitempty"` + Target *TypeDesc `yaml:"target,omitempty"` + Args []*TypeDesc `yaml:"args,omitempty"` + Return *TypeDesc `yaml:"return,omitempty"` } // AsFunctionOption converts the serializable form of the Overload into a function declaration option. @@ -186,6 +297,18 @@ func (od *Overload) AsFunctionOption(tp types.Provider) (decls.FunctionOpt, erro return decls.Overload(od.ID, args, result), nil } +// NewExtension creates a serializable Extension from a name and version string. +func NewExtension(name string, version uint32) *Extension { + versionString := "latest" + if version < math.MaxUint32 { + versionString = strconv.FormatUint(uint64(version), 10) + } + return &Extension{ + Name: name, + Version: versionString, + } +} + // Extension represents a named and optionally versioned extension library configured in the environment. type Extension struct { // Name is either the LibraryName() or some short-hand simple identifier which is understood by the config-handler. @@ -213,28 +336,37 @@ func (e *Extension) GetVersion() (uint32, error) { return uint32(ver), nil } +// NewLibrarySubset returns an empty library subsetting config which permits all library features. 
+func NewLibrarySubset() *LibrarySubset { + return &LibrarySubset{} +} + // LibrarySubset indicates a subset of the macros and function supported by a subsettable library. type LibrarySubset struct { + // Disabled indicates whether the library has been disabled, typically only used for + // default-enabled libraries like stdlib. + Disabled bool `yaml:"disabled,omitempty"` + // DisableMacros disables macros for the given library. - DisableMacros bool `yaml:"disable_macros"` + DisableMacros bool `yaml:"disable_macros,omitempty"` // IncludeMacros specifies a set of macro function names to include in the subset. - IncludeMacros []string `yaml:"include_macros"` + IncludeMacros []string `yaml:"include_macros,omitempty"` // ExcludeMacros specifies a set of macro function names to exclude from the subset. // Note: if IncludeMacros is non-empty, then ExcludeFunctions is ignored. - ExcludeMacros []string `yaml:"exclude_macros"` + ExcludeMacros []string `yaml:"exclude_macros,omitempty"` // IncludeFunctions specifies a set of functions to include in the subset. // // Note: the overloads specified in the subset need only specify their ID. // Note: if IncludeFunctions is non-empty, then ExcludeFunctions is ignored. - IncludeFunctions []*Function `yaml:"include_functions"` + IncludeFunctions []*Function `yaml:"include_functions,omitempty"` // ExcludeFunctions specifies the set of functions to exclude from the subset. // // Note: the overloads specified in the subset need only specify their ID. 
- ExcludeFunctions []*Function `yaml:"exclude_functions"` + ExcludeFunctions []*Function `yaml:"exclude_functions,omitempty"` } // SubsetFunction produces a function declaration which matches the supported subset, or nil @@ -252,6 +384,9 @@ func (lib *LibrarySubset) SubsetFunction(fn *decls.FunctionDecl) (*decls.Functio if lib == nil { return fn, true } + if lib.Disabled { + return nil, false + } if len(lib.IncludeFunctions) != 0 { for _, include := range lib.IncludeFunctions { if include.Name != fn.Name() { @@ -293,7 +428,7 @@ func (lib *LibrarySubset) SubsetMacro(macroFunction string) bool { if lib == nil { return true } - if lib.DisableMacros { + if lib.Disabled || lib.DisableMacros { return false } if len(lib.IncludeMacros) != 0 { @@ -315,11 +450,74 @@ func (lib *LibrarySubset) SubsetMacro(macroFunction string) bool { return true } +// SetDisabled disables or enables the library. +func (lib *LibrarySubset) SetDisabled(value bool) *LibrarySubset { + lib.Disabled = value + return lib +} + +// SetDisableMacros disables the macros for the library. +func (lib *LibrarySubset) SetDisableMacros(value bool) *LibrarySubset { + lib.DisableMacros = value + return lib +} + +// AddIncludedMacros allow-lists one or more macros by function name. +// +// Note, this option will override any excluded macros. +func (lib *LibrarySubset) AddIncludedMacros(macros ...string) *LibrarySubset { + lib.IncludeMacros = append(lib.IncludeMacros, macros...) + return lib +} + +// AddExcludedMacros deny-lists one or more macros by function name. +func (lib *LibrarySubset) AddExcludedMacros(macros ...string) *LibrarySubset { + lib.ExcludeMacros = append(lib.ExcludeMacros, macros...) + return lib +} + +// AddIncludedFunctions allow-lists one or more functions from the subset. +// +// Note, this option will override any excluded functions. +func (lib *LibrarySubset) AddIncludedFunctions(funcs ...*Function) *LibrarySubset { + lib.IncludeFunctions = append(lib.IncludeFunctions, funcs...) 
+ return lib +} + +// AddExcludedFunctions deny-lists one or more functions from the subset. +func (lib *LibrarySubset) AddExcludedFunctions(funcs ...*Function) *LibrarySubset { + lib.ExcludeFunctions = append(lib.ExcludeFunctions, funcs...) + return lib +} + +// NewTypeDesc describes a simple or complex type with parameters. +func NewTypeDesc(typeName string, params ...*TypeDesc) *TypeDesc { + return &TypeDesc{TypeName: typeName, Params: params} +} + +// NewTypeParam describe a type-param type. +func NewTypeParam(paramName string) *TypeDesc { + return &TypeDesc{TypeName: paramName, IsTypeParam: true} +} + // TypeDesc represents the serializable format of a CEL *types.Type value. type TypeDesc struct { TypeName string `yaml:"type_name"` - Params []*TypeDesc `yaml:"params"` - IsTypeParam bool `yaml:"is_type_param"` + Params []*TypeDesc `yaml:"params,omitempty"` + IsTypeParam bool `yaml:"is_type_param,omitempty"` +} + +// String implements the strings.Stringer interface method. +func (td *TypeDesc) String() string { + ps := make([]string, len(td.Params)) + for i, p := range td.Params { + ps[i] = p.String() + } + typeName := td.TypeName + if len(ps) != 0 { + typeName = fmt.Sprintf("%s(%s)", typeName, strings.Join(ps, ",")) + } + return typeName } // AsCELType converts the serializable object to a *types.Type value. @@ -391,3 +589,29 @@ func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) { return types.NewOpaqueType(td.TypeName, params...), nil } } + +func serializeTypeDesc(t *types.Type) *TypeDesc { + typeName := t.TypeName() + if t.Kind() == types.TypeParamKind { + return NewTypeParam(typeName) + } + if t != types.NullType && t.IsAssignableType(types.NullType) { + if wrapperTypeName, found := wrapperTypes[t.Kind()]; found { + return NewTypeDesc(wrapperTypeName) + } + } + var params []*TypeDesc + for _, p := range t.Parameters() { + params = append(params, serializeTypeDesc(p)) + } + return NewTypeDesc(typeName, params...) 
+} + +var wrapperTypes = map[types.Kind]string{ + types.BoolKind: "google.protobuf.BoolValue", + types.BytesKind: "google.protobuf.BytesValue", + types.DoubleKind: "google.protobuf.DoubleValue", + types.IntKind: "google.protobuf.Int64Value", + types.StringKind: "google.protobuf.StringValue", + types.UintKind: "google.protobuf.UInt64Value", +} diff --git a/common/env/env_test.go b/common/env/env_test.go index a59347aaa..157d5d35b 100644 --- a/common/env/env_test.go +++ b/common/env/env_test.go @@ -16,19 +16,273 @@ package env import ( "errors" + "fmt" "math" + "os" "reflect" "strings" "testing" + "gopkg.in/yaml.v3" + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" ) func TestConfig(t *testing.T) { - conf := NewConfig() - if conf == nil { - t.Fatal("got nil config, wanted non-nil value") + tests := []struct { + name string + want *Config + }{ + { + name: "context_env", + want: NewConfig("context-env"). + SetContainer("google.expr"). + AddImports(NewImport("google.expr.proto3.test.TestAllTypes")). + SetStdLib(NewLibrarySubset(). + AddIncludedMacros("has"). + AddIncludedFunctions([]*Function{ + {Name: operators.Equals}, + {Name: operators.NotEquals}, + {Name: operators.LogicalNot}, + {Name: operators.Less}, + {Name: operators.LessEquals}, + {Name: operators.Greater}, + {Name: operators.GreaterEquals}, + }...)). + AddExtensions(NewExtension("optional", math.MaxUint32), NewExtension("strings", 1)). + SetContextVariable(NewContextVariable("google.expr.proto3.test.TestAllTypes")). 
+ AddFunctions( + NewFunction("coalesce", []*Overload{ + NewOverload("coalesce_wrapped_int", + []*TypeDesc{NewTypeDesc("google.protobuf.Int64Value"), NewTypeDesc("int")}, + NewTypeDesc("int")), + NewOverload("coalesce_wrapped_double", + []*TypeDesc{NewTypeDesc("google.protobuf.DoubleValue"), NewTypeDesc("double")}, + NewTypeDesc("double")), + NewOverload("coalesce_wrapped_uint", + []*TypeDesc{NewTypeDesc("google.protobuf.UInt64Value"), NewTypeDesc("uint")}, + NewTypeDesc("uint")), + }), + ), + }, + { + name: "extended_env", + want: NewConfig("extended-env"). + SetContainer("google.expr"). + AddExtensions( + NewExtension("optional", 2), + NewExtension("math", math.MaxUint32), + ).AddVariables( + NewVariable("msg", NewTypeDesc("google.expr.proto3.test.TestAllTypes")), + ).AddFunctions( + NewFunction("isEmpty", []*Overload{ + NewMemberOverload("wrapper_string_isEmpty", + NewTypeDesc("google.protobuf.StringValue"), nil, + NewTypeDesc("bool")), + NewMemberOverload("list_isEmpty", + NewTypeDesc("list", NewTypeParam("T")), nil, + NewTypeDesc("bool")), + }), + ), + }, + { + name: "subset_env", + want: NewConfig("subset-env"). + SetStdLib(NewLibrarySubset(). + AddExcludedMacros("map", "filter"). 
+ AddExcludedFunctions( + []*Function{ + {Name: operators.Add, Overloads: []*Overload{ + {ID: overloads.AddBytes}, + {ID: overloads.AddList}, + {ID: overloads.AddString}, + }}, + {Name: overloads.Matches}, + {Name: overloads.TypeConvertTimestamp, Overloads: []*Overload{ + {ID: overloads.StringToTimestamp}, + }}, + {Name: overloads.TypeConvertDuration, Overloads: []*Overload{ + {ID: overloads.StringToDuration}, + }}, + }..., + )).AddVariables( + NewVariable("x", NewTypeDesc("int")), + NewVariable("y", NewTypeDesc("double")), + NewVariable("z", NewTypeDesc("uint")), + ), + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + fileName := fmt.Sprintf("testdata/%s.yaml", tc.name) + data, err := os.ReadFile(fileName) + if err != nil { + t.Fatalf("os.ReadFile(%q) failed: %v", fileName, err) + } + got := unmarshalYAML(t, data) + if got.Container != tc.want.Container { + t.Errorf("Container got %s, wanted %s", got.Container, tc.want.Container) + } + if !reflect.DeepEqual(got.Imports, tc.want.Imports) { + t.Errorf("Imports got %v, wanted %v", got.Imports, tc.want.Imports) + } + if !reflect.DeepEqual(got.StdLib, tc.want.StdLib) { + t.Errorf("StdLib got %v, wanted %v", got.StdLib, tc.want.StdLib) + } + if !reflect.DeepEqual(got.ContextVariable, tc.want.ContextVariable) { + t.Errorf("ContextVariable got %v, wanted %v", got.ContextVariable, tc.want.ContextVariable) + } + if len(got.Variables) != len(tc.want.Variables) { + t.Errorf("Variables count got %d, wanted %d", len(got.Variables), len(tc.want.Variables)) + } else { + for i, v := range got.Variables { + wv := tc.want.Variables[i] + if !reflect.DeepEqual(v, wv) { + t.Errorf("Variables[%d] not equal, got %v, wanted %v", i, v, wv) + } + } + } + if len(got.Functions) != len(tc.want.Functions) { + t.Errorf("Functions count got %d, wanted %d", len(got.Functions), len(tc.want.Functions)) + } else { + for i, f := range got.Functions { + wf := tc.want.Functions[i] + if f.Name != wf.Name { + 
t.Errorf("Functions[%d] not equal, got %v, wanted %v", i, f.Name, wf.Name) + } + if len(f.Overloads) != len(wf.Overloads) { + t.Errorf("Function %s got overload count: %d, wanted %d", f.Name, len(f.Overloads), len(wf.Overloads)) + } + for j, o := range f.Overloads { + wo := wf.Overloads[j] + if !reflect.DeepEqual(o, wo) { + t.Errorf("Overload[%d] got %v, wanted %v", j, o, wo) + } + } + } + } + }) + } +} + +func TestNewImport(t *testing.T) { + imp := NewImport("qualified.type.name") + if imp.Name != "qualified.type.name" { + t.Errorf("NewImport() got name: %s, wanted %s", imp.Name, "qualified.type.name") + } +} + +func TestNewContextVariable(t *testing.T) { + ctx := NewContextVariable("qualified.type.name") + if ctx.TypeName != "qualified.type.name" { + t.Errorf("NewContextVariable() got name: %s, wanted %s", ctx.TypeName, "qualified.type.name") + } +} + +func TestConfigAddVariableDecls(t *testing.T) { + tests := []struct { + name string + in *decls.VariableDecl + out *Variable + }{ + { + name: "nil var decl", + }, + { + name: "simple var decl", + in: decls.NewVariable("var", types.StringType), + out: NewVariable("var", NewTypeDesc("string")), + }, + { + name: "parameterized var decl", + in: decls.NewVariable("var", types.NewListType(types.NewTypeParamType("T"))), + out: NewVariable("var", NewTypeDesc("list", NewTypeParam("T"))), + }, + { + name: "opaque var decl", + in: decls.NewVariable("var", types.NewOpaqueType("bitvector")), + out: NewVariable("var", NewTypeDesc("bitvector")), + }, + { + name: "proto var decl", + in: decls.NewVariable("var", types.NewObjectType("google.type.Expr")), + out: NewVariable("var", NewTypeDesc("google.type.Expr")), + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + conf := NewConfig(tc.name).AddVariableDecls(tc.in) + if len(conf.Variables) != 1 { + t.Fatalf("AddVariableDecls() did not add declaration to conf: %v", conf) + } + if !reflect.DeepEqual(conf.Variables[0], tc.out) { + 
t.Errorf("AddVariableDecls() added %v, wanted %v", conf.Variables, tc.out) + } + }) + } +} + +func TestConfigAddVariableDeclsEmpty(t *testing.T) { + if len(NewConfig("").AddVariables().Variables) != 0 { + t.Error("AddVariables() with no args failed") + } +} + +func TestConfigAddFunctionDecls(t *testing.T) { + tests := []struct { + name string + in *decls.FunctionDecl + out *Function + }{ + { + name: "nil function decl", + }, + { + name: "global function decl", + in: mustNewFunction(t, "size", + decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType), + ), + out: NewFunction("size", []*Overload{ + NewOverload("size_string", []*TypeDesc{NewTypeDesc("string")}, NewTypeDesc("int")), + }), + }, + { + name: "global function decl - nullable arg", + in: mustNewFunction(t, "size", + decls.Overload("size_wrapper_string", []*types.Type{types.NewNullableType(types.StringType)}, types.IntType), + ), + out: NewFunction("size", []*Overload{ + NewOverload("size_wrapper_string", []*TypeDesc{NewTypeDesc("google.protobuf.StringValue")}, NewTypeDesc("int")), + }), + }, + { + name: "member function decl - nullable arg", + in: mustNewFunction(t, "size", + decls.MemberOverload("list_size", []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType), + decls.MemberOverload("string_size", []*types.Type{types.StringType}, types.IntType), + ), + out: NewFunction("size", []*Overload{ + NewMemberOverload("list_size", NewTypeDesc("list", NewTypeParam("T")), []*TypeDesc{}, NewTypeDesc("int")), + NewMemberOverload("string_size", NewTypeDesc("string"), []*TypeDesc{}, NewTypeDesc("int")), + }), + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + conf := NewConfig(tc.name).AddFunctionDecls(tc.in) + if len(conf.Functions) != 1 { + t.Fatalf("AddFunctionDecls() did not add declaration to conf: %v", conf) + } + if !reflect.DeepEqual(conf.Functions[0], tc.out) { + t.Errorf("AddFunctionDecls() added %v, wanted %v", 
conf.Functions, tc.out) + } + }) } } @@ -111,17 +365,14 @@ func TestVariableAsCELVariable(t *testing.T) { }, { name: "int type", - v: &Variable{ - Name: "int_var", - TypeDesc: &TypeDesc{TypeName: "int"}, - }, + v: NewVariable("int_var", NewTypeDesc("int")), want: decls.NewVariable("int_var", types.IntType), }, { name: "uint type", v: &Variable{ Name: "uint_var", - TypeDesc: &TypeDesc{TypeName: "uint"}, + TypeDesc: NewTypeDesc("uint"), }, want: decls.NewVariable("uint_var", types.UintType), }, @@ -129,7 +380,7 @@ func TestVariableAsCELVariable(t *testing.T) { name: "dyn type", v: &Variable{ Name: "dyn_var", - TypeDesc: &TypeDesc{TypeName: "dyn"}, + TypeDesc: NewTypeDesc("dyn"), }, want: decls.NewVariable("dyn_var", types.DynType), }, @@ -137,7 +388,7 @@ func TestVariableAsCELVariable(t *testing.T) { name: "list type", v: &Variable{ Name: "list_var", - TypeDesc: &TypeDesc{TypeName: "list", Params: []*TypeDesc{{TypeName: "T", IsTypeParam: true}}}, + TypeDesc: NewTypeDesc("list", NewTypeParam("T")), }, want: decls.NewVariable("list_var", types.NewListType(types.NewTypeParamType("T"))), }, @@ -148,9 +399,8 @@ func TestVariableAsCELVariable(t *testing.T) { TypeDesc: &TypeDesc{ TypeName: "map", Params: []*TypeDesc{ - {TypeName: "string"}, - {TypeName: "optional_type", - Params: []*TypeDesc{{TypeName: "T", IsTypeParam: true}}}, + NewTypeDesc("string"), + NewTypeDesc("optional_type", NewTypeParam("T")), }, }, }, @@ -160,13 +410,8 @@ func TestVariableAsCELVariable(t *testing.T) { { name: "set type", v: &Variable{ - Name: "set_var", - TypeDesc: &TypeDesc{ - TypeName: "set", - Params: []*TypeDesc{ - {TypeName: "string"}, - }, - }, + Name: "set_var", + TypeDesc: NewTypeDesc("set", NewTypeDesc("string")), }, want: decls.NewVariable("set_var", types.NewOpaqueType("set", types.StringType)), }, @@ -174,8 +419,8 @@ func TestVariableAsCELVariable(t *testing.T) { name: "string type - nested type precedence", v: &Variable{ Name: "hello", - TypeDesc: &TypeDesc{TypeName: "string"}, - 
Type: &TypeDesc{TypeName: "int"}, + TypeDesc: NewTypeDesc("string"), + Type: NewTypeDesc("int"), }, want: decls.NewVariable("hello", types.StringType), }, @@ -183,7 +428,7 @@ func TestVariableAsCELVariable(t *testing.T) { name: "wrapper type variable", v: &Variable{ Name: "msg", - TypeDesc: &TypeDesc{TypeName: "google.protobuf.StringValue"}, + TypeDesc: NewTypeDesc("google.protobuf.StringValue"), }, want: decls.NewVariable("msg", types.NewNullableType(types.StringType)), }, @@ -215,6 +460,22 @@ func TestVariableAsCELVariable(t *testing.T) { } } +func TestTypeDescString(t *testing.T) { + tests := []struct { + desc *TypeDesc + want string + }{ + {desc: NewTypeDesc("string"), want: "string"}, + {desc: NewTypeDesc("list", NewTypeParam("T")), want: "list(T)"}, + {desc: NewTypeDesc("map", NewTypeDesc("string"), NewTypeParam("T")), want: "map(string,T)"}, + } + for _, tc := range tests { + if tc.desc.String() != tc.want { + t.Errorf("String() got %s, wanted %s", tc.desc.String(), tc.want) + } + } +} + func TestFunctionAsCELFunction(t *testing.T) { tests := []struct { name string @@ -233,47 +494,33 @@ func TestFunctionAsCELFunction(t *testing.T) { }, { name: "no overloads", - f: &Function{Name: "no_overloads"}, + f: NewFunction("no_overloads", []*Overload{}), want: errors.New("must declare an overload"), }, { name: "nil overload", - f: &Function{Name: "no_overloads", Overloads: []*Overload{nil}}, + f: NewFunction("no_overloads", []*Overload{nil}), want: errors.New("nil Overload"), }, { name: "no return type", - f: &Function{Name: "size", - Overloads: []*Overload{ - {ID: "size_string", - Args: []*TypeDesc{{TypeName: "string"}}, - }, - }, - }, + f: NewFunction("size", []*Overload{ + NewOverload("size_string", []*TypeDesc{NewTypeDesc("string")}, nil), + }), want: errors.New("missing return type"), }, { name: "bad return type", - f: &Function{Name: "size", - Overloads: []*Overload{ - {ID: "size_string", - Args: []*TypeDesc{{TypeName: "string"}}, - Return: &TypeDesc{}, - }, - 
}, - }, + f: NewFunction("size", []*Overload{ + NewOverload("size_string", []*TypeDesc{NewTypeDesc("string")}, NewTypeDesc("")), + }), want: errors.New("invalid type"), }, { name: "bad arg type", - f: &Function{Name: "size", - Overloads: []*Overload{ - {ID: "size_string", - Args: []*TypeDesc{{}}, - Return: &TypeDesc{}, - }, - }, - }, + f: NewFunction("size", []*Overload{ + NewOverload("size_string", []*TypeDesc{NewTypeDesc("")}, NewTypeDesc("")), + }), want: errors.New("invalid type"), }, { @@ -353,13 +600,13 @@ func TestTypeDescAsCELTypeErrors(t *testing.T) { want: errors.New("invalid type"), }, { - name: "invalid optional", - t: &TypeDesc{TypeName: "optional"}, + name: "invalid optional_type", + t: &TypeDesc{TypeName: "optional_type"}, want: errors.New("unexpected param count"), }, { name: "invalid optional param type", - t: &TypeDesc{TypeName: "optional", Params: []*TypeDesc{{}}}, + t: &TypeDesc{TypeName: "optional_type", Params: []*TypeDesc{{}}}, want: errors.New("invalid type"), }, { @@ -440,39 +687,39 @@ func TestSubsetFunction(t *testing.T) { }, { name: "empty, included", - lib: &LibrarySubset{}, + lib: NewLibrarySubset(), orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), included: true, }, + { + name: "empty, disabled", + lib: NewLibrarySubset().SetDisabled(true), + orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), + included: false, + }, { name: "lib, not included allow-list", - lib: &LibrarySubset{ - IncludeFunctions: []*Function{ - {Name: "int"}, - }, - }, + lib: NewLibrarySubset().AddIncludedFunctions([]*Function{ + {Name: "int"}, + }...), orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), included: false, }, { name: "lib, included whole function", - lib: 
&LibrarySubset{ - IncludeFunctions: []*Function{ - {Name: "size"}, - }, - }, + lib: NewLibrarySubset().AddIncludedFunctions([]*Function{ + {Name: "size"}, + }...), orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), included: true, }, { name: "lib, included overload subset", - lib: &LibrarySubset{ - IncludeFunctions: []*Function{ - {Name: "size", Overloads: []*Overload{{ID: "size_string"}}}, - }, - }, + lib: NewLibrarySubset().AddIncludedFunctions([]*Function{ + {Name: "size", Overloads: []*Overload{{ID: "size_string"}}}, + }...), orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType), decls.Overload("size_list", []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType), @@ -482,32 +729,26 @@ func TestSubsetFunction(t *testing.T) { }, { name: "lib, included deny-list", - lib: &LibrarySubset{ - ExcludeFunctions: []*Function{ - {Name: "int"}, - }, - }, + lib: NewLibrarySubset().AddExcludedFunctions([]*Function{ + {Name: "int"}, + }...), orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), subset: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), included: true, }, { name: "lib, excluded whole function", - lib: &LibrarySubset{ - ExcludeFunctions: []*Function{ - {Name: "size"}, - }, - }, + lib: NewLibrarySubset().AddExcludedFunctions([]*Function{ + {Name: "size"}, + }...), orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType)), included: false, }, { name: "lib, excluded partial function", - lib: &LibrarySubset{ - ExcludeFunctions: []*Function{ - {Name: "size", Overloads: []*Overload{{ID: "size_list"}}}, - }, - }, + lib: 
NewLibrarySubset().AddExcludedFunctions([]*Function{ + {Name: "size", Overloads: []*Overload{{ID: "size_list"}}}, + }...), orig: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType), decls.Overload("size_list", []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType), @@ -545,45 +786,43 @@ func TestSubsetMacro(t *testing.T) { }, { name: "empty, included", - lib: &LibrarySubset{}, + lib: NewLibrarySubset(), macroName: "has", included: true, }, + { + name: "empty, disabled", + lib: NewLibrarySubset().SetDisabled(true), + macroName: "has", + included: false, + }, { name: "empty, included", - lib: &LibrarySubset{DisableMacros: true}, + lib: NewLibrarySubset().SetDisableMacros(true), macroName: "has", included: false, }, { - name: "lib, not included allow-list", - lib: &LibrarySubset{ - IncludeMacros: []string{"exists"}, - }, + name: "lib, not included allow-list", + lib: NewLibrarySubset().AddIncludedMacros("exists"), macroName: "has", included: false, }, { - name: "lib, included allow-list", - lib: &LibrarySubset{ - IncludeMacros: []string{"exists"}, - }, + name: "lib, included allow-list", + lib: NewLibrarySubset().AddIncludedMacros("exists"), macroName: "exists", included: true, }, { - name: "lib, not included deny-list", - lib: &LibrarySubset{ - ExcludeMacros: []string{"exists"}, - }, + name: "lib, not included deny-list", + lib: NewLibrarySubset().AddExcludedMacros("exists"), macroName: "exists", included: false, }, { - name: "lib, included deny-list", - lib: &LibrarySubset{ - ExcludeMacros: []string{"exists"}, - }, + name: "lib, included deny-list", + lib: NewLibrarySubset().AddExcludedMacros("exists"), macroName: "has", included: true, }, @@ -599,6 +838,34 @@ func TestSubsetMacro(t *testing.T) { } } +func TestNewExtension(t *testing.T) { + tests := []struct { + name string + version uint32 + want *Extension + }{ + { + name: "strings", + version: math.MaxUint32, + want: &Extension{Name: 
"strings", Version: "latest"}, + }, + { + name: "bindings", + version: 1, + want: &Extension{Name: "bindings", Version: "1"}, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + got := NewExtension(tc.name, tc.version) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("NewExtension() got %v, wanted %v", got, tc.want) + } + }) + } +} + func TestExtensionGetVersion(t *testing.T) { tests := []struct { name string @@ -699,3 +966,21 @@ func assertFuncEquals(t *testing.T, got, want *decls.FunctionDecl) { } } } + +func unmarshalYAML(t *testing.T, data []byte) *Config { + t.Helper() + config := &Config{} + if err := yaml.Unmarshal(data, config); err != nil { + t.Fatalf("yaml.Unmarshal(%q) failed: %v", string(data), err) + } + return config +} + +func marshalYAML(t *testing.T, config *Config) []byte { + t.Helper() + data, err := yaml.Marshal(config) + if err != nil { + t.Fatalf("yaml.Marshal(%q) failed: %v", string(data), err) + } + return data +} diff --git a/common/env/testdata/context_env.yaml b/common/env/testdata/context_env.yaml new file mode 100644 index 000000000..348d2d723 --- /dev/null +++ b/common/env/testdata/context_env.yaml @@ -0,0 +1,57 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: "context-env" +container: "google.expr" +imports: + - name: "google.expr.proto3.test.TestAllTypes" +stdlib: + include_macros: + - has + include_functions: + - name: "_==_" + - name: "_!=_" + - name: "!_" + - name: "_<_" + - name: "_<=_" + - name: "_>_" + - name: "_>=_" +extensions: + - name: "optional" + version: "latest" + - name: "strings" + version: 1 +context_variable: + type_name: "google.expr.proto3.test.TestAllTypes" +functions: + - name: "coalesce" + overloads: + - id: "coalesce_wrapped_int" + args: + - type_name: "google.protobuf.Int64Value" + - type_name: "int" + return: + type_name: "int" + - id: "coalesce_wrapped_double" + args: + - type_name: "google.protobuf.DoubleValue" + - type_name: "double" + return: + type_name: "double" + - id: "coalesce_wrapped_uint" + args: + - type_name: "google.protobuf.UInt64Value" + - type_name: "uint" + return: + type_name: "uint" diff --git a/common/env/testdata/extended_env.yaml b/common/env/testdata/extended_env.yaml new file mode 100644 index 000000000..041002e75 --- /dev/null +++ b/common/env/testdata/extended_env.yaml @@ -0,0 +1,40 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: "extended-env" +container: "google.expr" +extensions: + - name: "optional" + version: "2" + - name: "math" + version: "latest" +variables: + - name: "msg" + type_name: "google.expr.proto3.test.TestAllTypes" +functions: + - name: "isEmpty" + overloads: + - id: "wrapper_string_isEmpty" + target: + type_name: "google.protobuf.StringValue" + return: + type_name: "bool" + - id: "list_isEmpty" + target: + type_name: "list" + params: + - type_name: "T" + is_type_param: true + return: + type_name: "bool" diff --git a/common/env/testdata/subset_env.yaml b/common/env/testdata/subset_env.yaml new file mode 100644 index 000000000..44437e718 --- /dev/null +++ b/common/env/testdata/subset_env.yaml @@ -0,0 +1,39 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: "subset-env" +stdlib: + exclude_macros: + - map + - filter + exclude_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: add_list + - id: add_string + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + - name: "duration" + overloads: + - id: "string_to_duration" +variables: + - name: "x" + type_name: "int" + - name: "y" + type_name: "double" + - name: "z" + type_name: "uint" diff --git a/go.mod b/go.mod index 7b976324f..ae23e9ee6 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( golang.org/x/text v0.16.0 google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 google.golang.org/protobuf v1.34.2 + gopkg.in/yaml.v3 v3.0.1 ) require ( diff --git a/go.sum b/go.sum index fbd276c1e..b518e1a53 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,9 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/policy/config.go b/policy/config.go index 50f2852a2..0501709b6 100644 --- a/policy/config.go +++ b/policy/config.go @@ -43,9 +43,18 @@ func (c *Config) AsEnvOptions(provider types.Provider) ([]cel.EnvOption, error) envOpts := []cel.EnvOption{} // 
Configure the standard lib subset. if c.StdLib != nil { - envOpts = append(envOpts, func(e *cel.Env) (*cel.Env, error) { - return cel.NewCustomEnv(cel.StdLib(cel.StdLibSubset(c.StdLib))) - }) + if c.StdLib.Disabled { + envOpts = append(envOpts, func(e *cel.Env) (*cel.Env, error) { + if !e.HasLibrary("cel.lib.std") { + return e, nil + } + return cel.NewCustomEnv() + }) + } else { + envOpts = append(envOpts, func(e *cel.Env) (*cel.Env, error) { + return cel.NewCustomEnv(cel.StdLib(cel.StdLibSubset(c.StdLib))) + }) + } } // Configure the container diff --git a/vendor/gopkg.in/yaml.v3/LICENSE b/vendor/gopkg.in/yaml.v3/LICENSE new file mode 100644 index 000000000..2683e4bb1 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/LICENSE @@ -0,0 +1,50 @@ + +This project is covered by two different licenses: MIT and Apache. + +#### MIT License #### + +The following files were ported to Go from C files of libyaml, and thus +are still covered by their original MIT license, with the additional +copyright staring in 2011 when the project was ported over: + + apic.go emitterc.go parserc.go readerc.go scannerc.go + writerc.go yamlh.go yamlprivateh.go + +Copyright (c) 2006-2010 Kirill Simonov +Copyright (c) 2006-2011 Kirill Simonov + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +### Apache License ### + +All the remaining project files are covered by the Apache license: + +Copyright (c) 2011-2019 Canonical Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/vendor/gopkg.in/yaml.v3/NOTICE b/vendor/gopkg.in/yaml.v3/NOTICE new file mode 100644 index 000000000..866d74a7a --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/NOTICE @@ -0,0 +1,13 @@ +Copyright 2011-2016 Canonical Ltd. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/vendor/gopkg.in/yaml.v3/README.md b/vendor/gopkg.in/yaml.v3/README.md new file mode 100644 index 000000000..08eb1babd --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/README.md @@ -0,0 +1,150 @@ +# YAML support for the Go language + +Introduction +------------ + +The yaml package enables Go programs to comfortably encode and decode YAML +values. It was developed within [Canonical](https://www.canonical.com) as +part of the [juju](https://juju.ubuntu.com) project, and is based on a +pure Go port of the well-known [libyaml](http://pyyaml.org/wiki/LibYAML) +C library to parse and generate YAML data quickly and reliably. + +Compatibility +------------- + +The yaml package supports most of YAML 1.2, but preserves some behavior +from 1.1 for backwards compatibility. + +Specifically, as of v3 of the yaml package: + + - YAML 1.1 bools (_yes/no, on/off_) are supported as long as they are being + decoded into a typed bool value. Otherwise they behave as a string. Booleans + in YAML 1.2 are _true/false_ only. + - Octals encode and decode as _0777_ per YAML 1.1, rather than _0o777_ + as specified in YAML 1.2, because most parsers still use the old format. + Octals in the _0o777_ format are supported though, so new files work. + - Does not support base-60 floats. These are gone from YAML 1.2, and were + actually never supported by this package as it's clearly a poor choice. + +and offers backwards +compatibility with YAML 1.1 in some cases. +1.2, including support for +anchors, tags, map merging, etc. Multi-document unmarshalling is not yet +implemented, and base-60 floats from YAML 1.1 are purposefully not +supported since they're a poor design and are gone in YAML 1.2. + +Installation and usage +---------------------- + +The import path for the package is *gopkg.in/yaml.v3*. 
+ +To install it, run: + + go get gopkg.in/yaml.v3 + +API documentation +----------------- + +If opened in a browser, the import path itself leads to the API documentation: + + - [https://gopkg.in/yaml.v3](https://gopkg.in/yaml.v3) + +API stability +------------- + +The package API for yaml v3 will remain stable as described in [gopkg.in](https://gopkg.in). + + +License +------- + +The yaml package is licensed under the MIT and Apache License 2.0 licenses. +Please see the LICENSE file for details. + + +Example +------- + +```Go +package main + +import ( + "fmt" + "log" + + "gopkg.in/yaml.v3" +) + +var data = ` +a: Easy! +b: + c: 2 + d: [3, 4] +` + +// Note: struct fields must be public in order for unmarshal to +// correctly populate the data. +type T struct { + A string + B struct { + RenamedC int `yaml:"c"` + D []int `yaml:",flow"` + } +} + +func main() { + t := T{} + + err := yaml.Unmarshal([]byte(data), &t) + if err != nil { + log.Fatalf("error: %v", err) + } + fmt.Printf("--- t:\n%v\n\n", t) + + d, err := yaml.Marshal(&t) + if err != nil { + log.Fatalf("error: %v", err) + } + fmt.Printf("--- t dump:\n%s\n\n", string(d)) + + m := make(map[interface{}]interface{}) + + err = yaml.Unmarshal([]byte(data), &m) + if err != nil { + log.Fatalf("error: %v", err) + } + fmt.Printf("--- m:\n%v\n\n", m) + + d, err = yaml.Marshal(&m) + if err != nil { + log.Fatalf("error: %v", err) + } + fmt.Printf("--- m dump:\n%s\n\n", string(d)) +} +``` + +This example will generate the following output: + +``` +--- t: +{Easy! {2 [3 4]}} + +--- t dump: +a: Easy! +b: + c: 2 + d: [3, 4] + + +--- m: +map[a:Easy! b:map[c:2 d:[3 4]]] + +--- m dump: +a: Easy! 
+b: + c: 2 + d: + - 3 + - 4 +``` + diff --git a/vendor/gopkg.in/yaml.v3/apic.go b/vendor/gopkg.in/yaml.v3/apic.go new file mode 100644 index 000000000..ae7d049f1 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/apic.go @@ -0,0 +1,747 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// Copyright (c) 2006-2010 Kirill Simonov +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package yaml + +import ( + "io" +) + +func yaml_insert_token(parser *yaml_parser_t, pos int, token *yaml_token_t) { + //fmt.Println("yaml_insert_token", "pos:", pos, "typ:", token.typ, "head:", parser.tokens_head, "len:", len(parser.tokens)) + + // Check if we can move the queue at the beginning of the buffer. 
+ if parser.tokens_head > 0 && len(parser.tokens) == cap(parser.tokens) { + if parser.tokens_head != len(parser.tokens) { + copy(parser.tokens, parser.tokens[parser.tokens_head:]) + } + parser.tokens = parser.tokens[:len(parser.tokens)-parser.tokens_head] + parser.tokens_head = 0 + } + parser.tokens = append(parser.tokens, *token) + if pos < 0 { + return + } + copy(parser.tokens[parser.tokens_head+pos+1:], parser.tokens[parser.tokens_head+pos:]) + parser.tokens[parser.tokens_head+pos] = *token +} + +// Create a new parser object. +func yaml_parser_initialize(parser *yaml_parser_t) bool { + *parser = yaml_parser_t{ + raw_buffer: make([]byte, 0, input_raw_buffer_size), + buffer: make([]byte, 0, input_buffer_size), + } + return true +} + +// Destroy a parser object. +func yaml_parser_delete(parser *yaml_parser_t) { + *parser = yaml_parser_t{} +} + +// String read handler. +func yaml_string_read_handler(parser *yaml_parser_t, buffer []byte) (n int, err error) { + if parser.input_pos == len(parser.input) { + return 0, io.EOF + } + n = copy(buffer, parser.input[parser.input_pos:]) + parser.input_pos += n + return n, nil +} + +// Reader read handler. +func yaml_reader_read_handler(parser *yaml_parser_t, buffer []byte) (n int, err error) { + return parser.input_reader.Read(buffer) +} + +// Set a string input. +func yaml_parser_set_input_string(parser *yaml_parser_t, input []byte) { + if parser.read_handler != nil { + panic("must set the input source only once") + } + parser.read_handler = yaml_string_read_handler + parser.input = input + parser.input_pos = 0 +} + +// Set a file input. +func yaml_parser_set_input_reader(parser *yaml_parser_t, r io.Reader) { + if parser.read_handler != nil { + panic("must set the input source only once") + } + parser.read_handler = yaml_reader_read_handler + parser.input_reader = r +} + +// Set the source encoding. 
+func yaml_parser_set_encoding(parser *yaml_parser_t, encoding yaml_encoding_t) { + if parser.encoding != yaml_ANY_ENCODING { + panic("must set the encoding only once") + } + parser.encoding = encoding +} + +// Create a new emitter object. +func yaml_emitter_initialize(emitter *yaml_emitter_t) { + *emitter = yaml_emitter_t{ + buffer: make([]byte, output_buffer_size), + raw_buffer: make([]byte, 0, output_raw_buffer_size), + states: make([]yaml_emitter_state_t, 0, initial_stack_size), + events: make([]yaml_event_t, 0, initial_queue_size), + best_width: -1, + } +} + +// Destroy an emitter object. +func yaml_emitter_delete(emitter *yaml_emitter_t) { + *emitter = yaml_emitter_t{} +} + +// String write handler. +func yaml_string_write_handler(emitter *yaml_emitter_t, buffer []byte) error { + *emitter.output_buffer = append(*emitter.output_buffer, buffer...) + return nil +} + +// yaml_writer_write_handler uses emitter.output_writer to write the +// emitted text. +func yaml_writer_write_handler(emitter *yaml_emitter_t, buffer []byte) error { + _, err := emitter.output_writer.Write(buffer) + return err +} + +// Set a string output. +func yaml_emitter_set_output_string(emitter *yaml_emitter_t, output_buffer *[]byte) { + if emitter.write_handler != nil { + panic("must set the output target only once") + } + emitter.write_handler = yaml_string_write_handler + emitter.output_buffer = output_buffer +} + +// Set a file output. +func yaml_emitter_set_output_writer(emitter *yaml_emitter_t, w io.Writer) { + if emitter.write_handler != nil { + panic("must set the output target only once") + } + emitter.write_handler = yaml_writer_write_handler + emitter.output_writer = w +} + +// Set the output encoding. +func yaml_emitter_set_encoding(emitter *yaml_emitter_t, encoding yaml_encoding_t) { + if emitter.encoding != yaml_ANY_ENCODING { + panic("must set the output encoding only once") + } + emitter.encoding = encoding +} + +// Set the canonical output style. 
+func yaml_emitter_set_canonical(emitter *yaml_emitter_t, canonical bool) { + emitter.canonical = canonical +} + +// Set the indentation increment. +func yaml_emitter_set_indent(emitter *yaml_emitter_t, indent int) { + if indent < 2 || indent > 9 { + indent = 2 + } + emitter.best_indent = indent +} + +// Set the preferred line width. +func yaml_emitter_set_width(emitter *yaml_emitter_t, width int) { + if width < 0 { + width = -1 + } + emitter.best_width = width +} + +// Set if unescaped non-ASCII characters are allowed. +func yaml_emitter_set_unicode(emitter *yaml_emitter_t, unicode bool) { + emitter.unicode = unicode +} + +// Set the preferred line break character. +func yaml_emitter_set_break(emitter *yaml_emitter_t, line_break yaml_break_t) { + emitter.line_break = line_break +} + +///* +// * Destroy a token object. +// */ +// +//YAML_DECLARE(void) +//yaml_token_delete(yaml_token_t *token) +//{ +// assert(token); // Non-NULL token object expected. +// +// switch (token.type) +// { +// case YAML_TAG_DIRECTIVE_TOKEN: +// yaml_free(token.data.tag_directive.handle); +// yaml_free(token.data.tag_directive.prefix); +// break; +// +// case YAML_ALIAS_TOKEN: +// yaml_free(token.data.alias.value); +// break; +// +// case YAML_ANCHOR_TOKEN: +// yaml_free(token.data.anchor.value); +// break; +// +// case YAML_TAG_TOKEN: +// yaml_free(token.data.tag.handle); +// yaml_free(token.data.tag.suffix); +// break; +// +// case YAML_SCALAR_TOKEN: +// yaml_free(token.data.scalar.value); +// break; +// +// default: +// break; +// } +// +// memset(token, 0, sizeof(yaml_token_t)); +//} +// +///* +// * Check if a string is a valid UTF-8 sequence. +// * +// * Check 'reader.c' for more details on UTF-8 encoding. 
+// */ +// +//static int +//yaml_check_utf8(yaml_char_t *start, size_t length) +//{ +// yaml_char_t *end = start+length; +// yaml_char_t *pointer = start; +// +// while (pointer < end) { +// unsigned char octet; +// unsigned int width; +// unsigned int value; +// size_t k; +// +// octet = pointer[0]; +// width = (octet & 0x80) == 0x00 ? 1 : +// (octet & 0xE0) == 0xC0 ? 2 : +// (octet & 0xF0) == 0xE0 ? 3 : +// (octet & 0xF8) == 0xF0 ? 4 : 0; +// value = (octet & 0x80) == 0x00 ? octet & 0x7F : +// (octet & 0xE0) == 0xC0 ? octet & 0x1F : +// (octet & 0xF0) == 0xE0 ? octet & 0x0F : +// (octet & 0xF8) == 0xF0 ? octet & 0x07 : 0; +// if (!width) return 0; +// if (pointer+width > end) return 0; +// for (k = 1; k < width; k ++) { +// octet = pointer[k]; +// if ((octet & 0xC0) != 0x80) return 0; +// value = (value << 6) + (octet & 0x3F); +// } +// if (!((width == 1) || +// (width == 2 && value >= 0x80) || +// (width == 3 && value >= 0x800) || +// (width == 4 && value >= 0x10000))) return 0; +// +// pointer += width; +// } +// +// return 1; +//} +// + +// Create STREAM-START. +func yaml_stream_start_event_initialize(event *yaml_event_t, encoding yaml_encoding_t) { + *event = yaml_event_t{ + typ: yaml_STREAM_START_EVENT, + encoding: encoding, + } +} + +// Create STREAM-END. +func yaml_stream_end_event_initialize(event *yaml_event_t) { + *event = yaml_event_t{ + typ: yaml_STREAM_END_EVENT, + } +} + +// Create DOCUMENT-START. +func yaml_document_start_event_initialize( + event *yaml_event_t, + version_directive *yaml_version_directive_t, + tag_directives []yaml_tag_directive_t, + implicit bool, +) { + *event = yaml_event_t{ + typ: yaml_DOCUMENT_START_EVENT, + version_directive: version_directive, + tag_directives: tag_directives, + implicit: implicit, + } +} + +// Create DOCUMENT-END. +func yaml_document_end_event_initialize(event *yaml_event_t, implicit bool) { + *event = yaml_event_t{ + typ: yaml_DOCUMENT_END_EVENT, + implicit: implicit, + } +} + +// Create ALIAS. 
+func yaml_alias_event_initialize(event *yaml_event_t, anchor []byte) bool { + *event = yaml_event_t{ + typ: yaml_ALIAS_EVENT, + anchor: anchor, + } + return true +} + +// Create SCALAR. +func yaml_scalar_event_initialize(event *yaml_event_t, anchor, tag, value []byte, plain_implicit, quoted_implicit bool, style yaml_scalar_style_t) bool { + *event = yaml_event_t{ + typ: yaml_SCALAR_EVENT, + anchor: anchor, + tag: tag, + value: value, + implicit: plain_implicit, + quoted_implicit: quoted_implicit, + style: yaml_style_t(style), + } + return true +} + +// Create SEQUENCE-START. +func yaml_sequence_start_event_initialize(event *yaml_event_t, anchor, tag []byte, implicit bool, style yaml_sequence_style_t) bool { + *event = yaml_event_t{ + typ: yaml_SEQUENCE_START_EVENT, + anchor: anchor, + tag: tag, + implicit: implicit, + style: yaml_style_t(style), + } + return true +} + +// Create SEQUENCE-END. +func yaml_sequence_end_event_initialize(event *yaml_event_t) bool { + *event = yaml_event_t{ + typ: yaml_SEQUENCE_END_EVENT, + } + return true +} + +// Create MAPPING-START. +func yaml_mapping_start_event_initialize(event *yaml_event_t, anchor, tag []byte, implicit bool, style yaml_mapping_style_t) { + *event = yaml_event_t{ + typ: yaml_MAPPING_START_EVENT, + anchor: anchor, + tag: tag, + implicit: implicit, + style: yaml_style_t(style), + } +} + +// Create MAPPING-END. +func yaml_mapping_end_event_initialize(event *yaml_event_t) { + *event = yaml_event_t{ + typ: yaml_MAPPING_END_EVENT, + } +} + +// Destroy an event object. +func yaml_event_delete(event *yaml_event_t) { + *event = yaml_event_t{} +} + +///* +// * Create a document object. 
+// */ +// +//YAML_DECLARE(int) +//yaml_document_initialize(document *yaml_document_t, +// version_directive *yaml_version_directive_t, +// tag_directives_start *yaml_tag_directive_t, +// tag_directives_end *yaml_tag_directive_t, +// start_implicit int, end_implicit int) +//{ +// struct { +// error yaml_error_type_t +// } context +// struct { +// start *yaml_node_t +// end *yaml_node_t +// top *yaml_node_t +// } nodes = { NULL, NULL, NULL } +// version_directive_copy *yaml_version_directive_t = NULL +// struct { +// start *yaml_tag_directive_t +// end *yaml_tag_directive_t +// top *yaml_tag_directive_t +// } tag_directives_copy = { NULL, NULL, NULL } +// value yaml_tag_directive_t = { NULL, NULL } +// mark yaml_mark_t = { 0, 0, 0 } +// +// assert(document) // Non-NULL document object is expected. +// assert((tag_directives_start && tag_directives_end) || +// (tag_directives_start == tag_directives_end)) +// // Valid tag directives are expected. +// +// if (!STACK_INIT(&context, nodes, INITIAL_STACK_SIZE)) goto error +// +// if (version_directive) { +// version_directive_copy = yaml_malloc(sizeof(yaml_version_directive_t)) +// if (!version_directive_copy) goto error +// version_directive_copy.major = version_directive.major +// version_directive_copy.minor = version_directive.minor +// } +// +// if (tag_directives_start != tag_directives_end) { +// tag_directive *yaml_tag_directive_t +// if (!STACK_INIT(&context, tag_directives_copy, INITIAL_STACK_SIZE)) +// goto error +// for (tag_directive = tag_directives_start +// tag_directive != tag_directives_end; tag_directive ++) { +// assert(tag_directive.handle) +// assert(tag_directive.prefix) +// if (!yaml_check_utf8(tag_directive.handle, +// strlen((char *)tag_directive.handle))) +// goto error +// if (!yaml_check_utf8(tag_directive.prefix, +// strlen((char *)tag_directive.prefix))) +// goto error +// value.handle = yaml_strdup(tag_directive.handle) +// value.prefix = yaml_strdup(tag_directive.prefix) +// if 
(!value.handle || !value.prefix) goto error +// if (!PUSH(&context, tag_directives_copy, value)) +// goto error +// value.handle = NULL +// value.prefix = NULL +// } +// } +// +// DOCUMENT_INIT(*document, nodes.start, nodes.end, version_directive_copy, +// tag_directives_copy.start, tag_directives_copy.top, +// start_implicit, end_implicit, mark, mark) +// +// return 1 +// +//error: +// STACK_DEL(&context, nodes) +// yaml_free(version_directive_copy) +// while (!STACK_EMPTY(&context, tag_directives_copy)) { +// value yaml_tag_directive_t = POP(&context, tag_directives_copy) +// yaml_free(value.handle) +// yaml_free(value.prefix) +// } +// STACK_DEL(&context, tag_directives_copy) +// yaml_free(value.handle) +// yaml_free(value.prefix) +// +// return 0 +//} +// +///* +// * Destroy a document object. +// */ +// +//YAML_DECLARE(void) +//yaml_document_delete(document *yaml_document_t) +//{ +// struct { +// error yaml_error_type_t +// } context +// tag_directive *yaml_tag_directive_t +// +// context.error = YAML_NO_ERROR // Eliminate a compiler warning. +// +// assert(document) // Non-NULL document object is expected. +// +// while (!STACK_EMPTY(&context, document.nodes)) { +// node yaml_node_t = POP(&context, document.nodes) +// yaml_free(node.tag) +// switch (node.type) { +// case YAML_SCALAR_NODE: +// yaml_free(node.data.scalar.value) +// break +// case YAML_SEQUENCE_NODE: +// STACK_DEL(&context, node.data.sequence.items) +// break +// case YAML_MAPPING_NODE: +// STACK_DEL(&context, node.data.mapping.pairs) +// break +// default: +// assert(0) // Should not happen. 
+// } +// } +// STACK_DEL(&context, document.nodes) +// +// yaml_free(document.version_directive) +// for (tag_directive = document.tag_directives.start +// tag_directive != document.tag_directives.end +// tag_directive++) { +// yaml_free(tag_directive.handle) +// yaml_free(tag_directive.prefix) +// } +// yaml_free(document.tag_directives.start) +// +// memset(document, 0, sizeof(yaml_document_t)) +//} +// +///** +// * Get a document node. +// */ +// +//YAML_DECLARE(yaml_node_t *) +//yaml_document_get_node(document *yaml_document_t, index int) +//{ +// assert(document) // Non-NULL document object is expected. +// +// if (index > 0 && document.nodes.start + index <= document.nodes.top) { +// return document.nodes.start + index - 1 +// } +// return NULL +//} +// +///** +// * Get the root object. +// */ +// +//YAML_DECLARE(yaml_node_t *) +//yaml_document_get_root_node(document *yaml_document_t) +//{ +// assert(document) // Non-NULL document object is expected. +// +// if (document.nodes.top != document.nodes.start) { +// return document.nodes.start +// } +// return NULL +//} +// +///* +// * Add a scalar node to a document. +// */ +// +//YAML_DECLARE(int) +//yaml_document_add_scalar(document *yaml_document_t, +// tag *yaml_char_t, value *yaml_char_t, length int, +// style yaml_scalar_style_t) +//{ +// struct { +// error yaml_error_type_t +// } context +// mark yaml_mark_t = { 0, 0, 0 } +// tag_copy *yaml_char_t = NULL +// value_copy *yaml_char_t = NULL +// node yaml_node_t +// +// assert(document) // Non-NULL document object is expected. +// assert(value) // Non-NULL value is expected. 
+// +// if (!tag) { +// tag = (yaml_char_t *)YAML_DEFAULT_SCALAR_TAG +// } +// +// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error +// tag_copy = yaml_strdup(tag) +// if (!tag_copy) goto error +// +// if (length < 0) { +// length = strlen((char *)value) +// } +// +// if (!yaml_check_utf8(value, length)) goto error +// value_copy = yaml_malloc(length+1) +// if (!value_copy) goto error +// memcpy(value_copy, value, length) +// value_copy[length] = '\0' +// +// SCALAR_NODE_INIT(node, tag_copy, value_copy, length, style, mark, mark) +// if (!PUSH(&context, document.nodes, node)) goto error +// +// return document.nodes.top - document.nodes.start +// +//error: +// yaml_free(tag_copy) +// yaml_free(value_copy) +// +// return 0 +//} +// +///* +// * Add a sequence node to a document. +// */ +// +//YAML_DECLARE(int) +//yaml_document_add_sequence(document *yaml_document_t, +// tag *yaml_char_t, style yaml_sequence_style_t) +//{ +// struct { +// error yaml_error_type_t +// } context +// mark yaml_mark_t = { 0, 0, 0 } +// tag_copy *yaml_char_t = NULL +// struct { +// start *yaml_node_item_t +// end *yaml_node_item_t +// top *yaml_node_item_t +// } items = { NULL, NULL, NULL } +// node yaml_node_t +// +// assert(document) // Non-NULL document object is expected. +// +// if (!tag) { +// tag = (yaml_char_t *)YAML_DEFAULT_SEQUENCE_TAG +// } +// +// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error +// tag_copy = yaml_strdup(tag) +// if (!tag_copy) goto error +// +// if (!STACK_INIT(&context, items, INITIAL_STACK_SIZE)) goto error +// +// SEQUENCE_NODE_INIT(node, tag_copy, items.start, items.end, +// style, mark, mark) +// if (!PUSH(&context, document.nodes, node)) goto error +// +// return document.nodes.top - document.nodes.start +// +//error: +// STACK_DEL(&context, items) +// yaml_free(tag_copy) +// +// return 0 +//} +// +///* +// * Add a mapping node to a document. 
+// */ +// +//YAML_DECLARE(int) +//yaml_document_add_mapping(document *yaml_document_t, +// tag *yaml_char_t, style yaml_mapping_style_t) +//{ +// struct { +// error yaml_error_type_t +// } context +// mark yaml_mark_t = { 0, 0, 0 } +// tag_copy *yaml_char_t = NULL +// struct { +// start *yaml_node_pair_t +// end *yaml_node_pair_t +// top *yaml_node_pair_t +// } pairs = { NULL, NULL, NULL } +// node yaml_node_t +// +// assert(document) // Non-NULL document object is expected. +// +// if (!tag) { +// tag = (yaml_char_t *)YAML_DEFAULT_MAPPING_TAG +// } +// +// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error +// tag_copy = yaml_strdup(tag) +// if (!tag_copy) goto error +// +// if (!STACK_INIT(&context, pairs, INITIAL_STACK_SIZE)) goto error +// +// MAPPING_NODE_INIT(node, tag_copy, pairs.start, pairs.end, +// style, mark, mark) +// if (!PUSH(&context, document.nodes, node)) goto error +// +// return document.nodes.top - document.nodes.start +// +//error: +// STACK_DEL(&context, pairs) +// yaml_free(tag_copy) +// +// return 0 +//} +// +///* +// * Append an item to a sequence node. +// */ +// +//YAML_DECLARE(int) +//yaml_document_append_sequence_item(document *yaml_document_t, +// sequence int, item int) +//{ +// struct { +// error yaml_error_type_t +// } context +// +// assert(document) // Non-NULL document is required. +// assert(sequence > 0 +// && document.nodes.start + sequence <= document.nodes.top) +// // Valid sequence id is required. +// assert(document.nodes.start[sequence-1].type == YAML_SEQUENCE_NODE) +// // A sequence node is required. +// assert(item > 0 && document.nodes.start + item <= document.nodes.top) +// // Valid item id is required. +// +// if (!PUSH(&context, +// document.nodes.start[sequence-1].data.sequence.items, item)) +// return 0 +// +// return 1 +//} +// +///* +// * Append a pair of a key and a value to a mapping node. 
+// */ +// +//YAML_DECLARE(int) +//yaml_document_append_mapping_pair(document *yaml_document_t, +// mapping int, key int, value int) +//{ +// struct { +// error yaml_error_type_t +// } context +// +// pair yaml_node_pair_t +// +// assert(document) // Non-NULL document is required. +// assert(mapping > 0 +// && document.nodes.start + mapping <= document.nodes.top) +// // Valid mapping id is required. +// assert(document.nodes.start[mapping-1].type == YAML_MAPPING_NODE) +// // A mapping node is required. +// assert(key > 0 && document.nodes.start + key <= document.nodes.top) +// // Valid key id is required. +// assert(value > 0 && document.nodes.start + value <= document.nodes.top) +// // Valid value id is required. +// +// pair.key = key +// pair.value = value +// +// if (!PUSH(&context, +// document.nodes.start[mapping-1].data.mapping.pairs, pair)) +// return 0 +// +// return 1 +//} +// +// diff --git a/vendor/gopkg.in/yaml.v3/decode.go b/vendor/gopkg.in/yaml.v3/decode.go new file mode 100644 index 000000000..0173b6982 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/decode.go @@ -0,0 +1,1000 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "encoding" + "encoding/base64" + "fmt" + "io" + "math" + "reflect" + "strconv" + "time" +) + +// ---------------------------------------------------------------------------- +// Parser, produces a node tree out of a libyaml event stream. 
+ +type parser struct { + parser yaml_parser_t + event yaml_event_t + doc *Node + anchors map[string]*Node + doneInit bool + textless bool +} + +func newParser(b []byte) *parser { + p := parser{} + if !yaml_parser_initialize(&p.parser) { + panic("failed to initialize YAML emitter") + } + if len(b) == 0 { + b = []byte{'\n'} + } + yaml_parser_set_input_string(&p.parser, b) + return &p +} + +func newParserFromReader(r io.Reader) *parser { + p := parser{} + if !yaml_parser_initialize(&p.parser) { + panic("failed to initialize YAML emitter") + } + yaml_parser_set_input_reader(&p.parser, r) + return &p +} + +func (p *parser) init() { + if p.doneInit { + return + } + p.anchors = make(map[string]*Node) + p.expect(yaml_STREAM_START_EVENT) + p.doneInit = true +} + +func (p *parser) destroy() { + if p.event.typ != yaml_NO_EVENT { + yaml_event_delete(&p.event) + } + yaml_parser_delete(&p.parser) +} + +// expect consumes an event from the event stream and +// checks that it's of the expected type. +func (p *parser) expect(e yaml_event_type_t) { + if p.event.typ == yaml_NO_EVENT { + if !yaml_parser_parse(&p.parser, &p.event) { + p.fail() + } + } + if p.event.typ == yaml_STREAM_END_EVENT { + failf("attempted to go past the end of stream; corrupted value?") + } + if p.event.typ != e { + p.parser.problem = fmt.Sprintf("expected %s event but got %s", e, p.event.typ) + p.fail() + } + yaml_event_delete(&p.event) + p.event.typ = yaml_NO_EVENT +} + +// peek peeks at the next event in the event stream, +// puts the results into p.event and returns the event type. +func (p *parser) peek() yaml_event_type_t { + if p.event.typ != yaml_NO_EVENT { + return p.event.typ + } + // It's curious choice from the underlying API to generally return a + // positive result on success, but on this case return true in an error + // scenario. This was the source of bugs in the past (issue #666). 
+ if !yaml_parser_parse(&p.parser, &p.event) || p.parser.error != yaml_NO_ERROR { + p.fail() + } + return p.event.typ +} + +func (p *parser) fail() { + var where string + var line int + if p.parser.context_mark.line != 0 { + line = p.parser.context_mark.line + // Scanner errors don't iterate line before returning error + if p.parser.error == yaml_SCANNER_ERROR { + line++ + } + } else if p.parser.problem_mark.line != 0 { + line = p.parser.problem_mark.line + // Scanner errors don't iterate line before returning error + if p.parser.error == yaml_SCANNER_ERROR { + line++ + } + } + if line != 0 { + where = "line " + strconv.Itoa(line) + ": " + } + var msg string + if len(p.parser.problem) > 0 { + msg = p.parser.problem + } else { + msg = "unknown problem parsing YAML content" + } + failf("%s%s", where, msg) +} + +func (p *parser) anchor(n *Node, anchor []byte) { + if anchor != nil { + n.Anchor = string(anchor) + p.anchors[n.Anchor] = n + } +} + +func (p *parser) parse() *Node { + p.init() + switch p.peek() { + case yaml_SCALAR_EVENT: + return p.scalar() + case yaml_ALIAS_EVENT: + return p.alias() + case yaml_MAPPING_START_EVENT: + return p.mapping() + case yaml_SEQUENCE_START_EVENT: + return p.sequence() + case yaml_DOCUMENT_START_EVENT: + return p.document() + case yaml_STREAM_END_EVENT: + // Happens when attempting to decode an empty buffer. + return nil + case yaml_TAIL_COMMENT_EVENT: + panic("internal error: unexpected tail comment event (please report)") + default: + panic("internal error: attempted to parse unknown event (please report): " + p.event.typ.String()) + } +} + +func (p *parser) node(kind Kind, defaultTag, tag, value string) *Node { + var style Style + if tag != "" && tag != "!" 
{ + tag = shortTag(tag) + style = TaggedStyle + } else if defaultTag != "" { + tag = defaultTag + } else if kind == ScalarNode { + tag, _ = resolve("", value) + } + n := &Node{ + Kind: kind, + Tag: tag, + Value: value, + Style: style, + } + if !p.textless { + n.Line = p.event.start_mark.line + 1 + n.Column = p.event.start_mark.column + 1 + n.HeadComment = string(p.event.head_comment) + n.LineComment = string(p.event.line_comment) + n.FootComment = string(p.event.foot_comment) + } + return n +} + +func (p *parser) parseChild(parent *Node) *Node { + child := p.parse() + parent.Content = append(parent.Content, child) + return child +} + +func (p *parser) document() *Node { + n := p.node(DocumentNode, "", "", "") + p.doc = n + p.expect(yaml_DOCUMENT_START_EVENT) + p.parseChild(n) + if p.peek() == yaml_DOCUMENT_END_EVENT { + n.FootComment = string(p.event.foot_comment) + } + p.expect(yaml_DOCUMENT_END_EVENT) + return n +} + +func (p *parser) alias() *Node { + n := p.node(AliasNode, "", "", string(p.event.anchor)) + n.Alias = p.anchors[n.Value] + if n.Alias == nil { + failf("unknown anchor '%s' referenced", n.Value) + } + p.expect(yaml_ALIAS_EVENT) + return n +} + +func (p *parser) scalar() *Node { + var parsedStyle = p.event.scalar_style() + var nodeStyle Style + switch { + case parsedStyle&yaml_DOUBLE_QUOTED_SCALAR_STYLE != 0: + nodeStyle = DoubleQuotedStyle + case parsedStyle&yaml_SINGLE_QUOTED_SCALAR_STYLE != 0: + nodeStyle = SingleQuotedStyle + case parsedStyle&yaml_LITERAL_SCALAR_STYLE != 0: + nodeStyle = LiteralStyle + case parsedStyle&yaml_FOLDED_SCALAR_STYLE != 0: + nodeStyle = FoldedStyle + } + var nodeValue = string(p.event.value) + var nodeTag = string(p.event.tag) + var defaultTag string + if nodeStyle == 0 { + if nodeValue == "<<" { + defaultTag = mergeTag + } + } else { + defaultTag = strTag + } + n := p.node(ScalarNode, defaultTag, nodeTag, nodeValue) + n.Style |= nodeStyle + p.anchor(n, p.event.anchor) + p.expect(yaml_SCALAR_EVENT) + return n +} + +func 
(p *parser) sequence() *Node { + n := p.node(SequenceNode, seqTag, string(p.event.tag), "") + if p.event.sequence_style()&yaml_FLOW_SEQUENCE_STYLE != 0 { + n.Style |= FlowStyle + } + p.anchor(n, p.event.anchor) + p.expect(yaml_SEQUENCE_START_EVENT) + for p.peek() != yaml_SEQUENCE_END_EVENT { + p.parseChild(n) + } + n.LineComment = string(p.event.line_comment) + n.FootComment = string(p.event.foot_comment) + p.expect(yaml_SEQUENCE_END_EVENT) + return n +} + +func (p *parser) mapping() *Node { + n := p.node(MappingNode, mapTag, string(p.event.tag), "") + block := true + if p.event.mapping_style()&yaml_FLOW_MAPPING_STYLE != 0 { + block = false + n.Style |= FlowStyle + } + p.anchor(n, p.event.anchor) + p.expect(yaml_MAPPING_START_EVENT) + for p.peek() != yaml_MAPPING_END_EVENT { + k := p.parseChild(n) + if block && k.FootComment != "" { + // Must be a foot comment for the prior value when being dedented. + if len(n.Content) > 2 { + n.Content[len(n.Content)-3].FootComment = k.FootComment + k.FootComment = "" + } + } + v := p.parseChild(n) + if k.FootComment == "" && v.FootComment != "" { + k.FootComment = v.FootComment + v.FootComment = "" + } + if p.peek() == yaml_TAIL_COMMENT_EVENT { + if k.FootComment == "" { + k.FootComment = string(p.event.foot_comment) + } + p.expect(yaml_TAIL_COMMENT_EVENT) + } + } + n.LineComment = string(p.event.line_comment) + n.FootComment = string(p.event.foot_comment) + if n.Style&FlowStyle == 0 && n.FootComment != "" && len(n.Content) > 1 { + n.Content[len(n.Content)-2].FootComment = n.FootComment + n.FootComment = "" + } + p.expect(yaml_MAPPING_END_EVENT) + return n +} + +// ---------------------------------------------------------------------------- +// Decoder, unmarshals a node into a provided value. 
+ +type decoder struct { + doc *Node + aliases map[*Node]bool + terrors []string + + stringMapType reflect.Type + generalMapType reflect.Type + + knownFields bool + uniqueKeys bool + decodeCount int + aliasCount int + aliasDepth int + + mergedFields map[interface{}]bool +} + +var ( + nodeType = reflect.TypeOf(Node{}) + durationType = reflect.TypeOf(time.Duration(0)) + stringMapType = reflect.TypeOf(map[string]interface{}{}) + generalMapType = reflect.TypeOf(map[interface{}]interface{}{}) + ifaceType = generalMapType.Elem() + timeType = reflect.TypeOf(time.Time{}) + ptrTimeType = reflect.TypeOf(&time.Time{}) +) + +func newDecoder() *decoder { + d := &decoder{ + stringMapType: stringMapType, + generalMapType: generalMapType, + uniqueKeys: true, + } + d.aliases = make(map[*Node]bool) + return d +} + +func (d *decoder) terror(n *Node, tag string, out reflect.Value) { + if n.Tag != "" { + tag = n.Tag + } + value := n.Value + if tag != seqTag && tag != mapTag { + if len(value) > 10 { + value = " `" + value[:7] + "...`" + } else { + value = " `" + value + "`" + } + } + d.terrors = append(d.terrors, fmt.Sprintf("line %d: cannot unmarshal %s%s into %s", n.Line, shortTag(tag), value, out.Type())) +} + +func (d *decoder) callUnmarshaler(n *Node, u Unmarshaler) (good bool) { + err := u.UnmarshalYAML(n) + if e, ok := err.(*TypeError); ok { + d.terrors = append(d.terrors, e.Errors...) + return false + } + if err != nil { + fail(err) + } + return true +} + +func (d *decoder) callObsoleteUnmarshaler(n *Node, u obsoleteUnmarshaler) (good bool) { + terrlen := len(d.terrors) + err := u.UnmarshalYAML(func(v interface{}) (err error) { + defer handleErr(&err) + d.unmarshal(n, reflect.ValueOf(v)) + if len(d.terrors) > terrlen { + issues := d.terrors[terrlen:] + d.terrors = d.terrors[:terrlen] + return &TypeError{issues} + } + return nil + }) + if e, ok := err.(*TypeError); ok { + d.terrors = append(d.terrors, e.Errors...) 
+ return false + } + if err != nil { + fail(err) + } + return true +} + +// d.prepare initializes and dereferences pointers and calls UnmarshalYAML +// if a value is found to implement it. +// It returns the initialized and dereferenced out value, whether +// unmarshalling was already done by UnmarshalYAML, and if so whether +// its types unmarshalled appropriately. +// +// If n holds a null value, prepare returns before doing anything. +func (d *decoder) prepare(n *Node, out reflect.Value) (newout reflect.Value, unmarshaled, good bool) { + if n.ShortTag() == nullTag { + return out, false, false + } + again := true + for again { + again = false + if out.Kind() == reflect.Ptr { + if out.IsNil() { + out.Set(reflect.New(out.Type().Elem())) + } + out = out.Elem() + again = true + } + if out.CanAddr() { + outi := out.Addr().Interface() + if u, ok := outi.(Unmarshaler); ok { + good = d.callUnmarshaler(n, u) + return out, true, good + } + if u, ok := outi.(obsoleteUnmarshaler); ok { + good = d.callObsoleteUnmarshaler(n, u) + return out, true, good + } + } + } + return out, false, false +} + +func (d *decoder) fieldByIndex(n *Node, v reflect.Value, index []int) (field reflect.Value) { + if n.ShortTag() == nullTag { + return reflect.Value{} + } + for _, num := range index { + for { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + continue + } + break + } + v = v.Field(num) + } + return v +} + +const ( + // 400,000 decode operations is ~500kb of dense object declarations, or + // ~5kb of dense object declarations with 10000% alias expansion + alias_ratio_range_low = 400000 + + // 4,000,000 decode operations is ~5MB of dense object declarations, or + // ~4.5MB of dense object declarations with 10% alias expansion + alias_ratio_range_high = 4000000 + + // alias_ratio_range is the range over which we scale allowed alias ratios + alias_ratio_range = float64(alias_ratio_range_high - alias_ratio_range_low) +) + +func 
allowedAliasRatio(decodeCount int) float64 { + switch { + case decodeCount <= alias_ratio_range_low: + // allow 99% to come from alias expansion for small-to-medium documents + return 0.99 + case decodeCount >= alias_ratio_range_high: + // allow 10% to come from alias expansion for very large documents + return 0.10 + default: + // scale smoothly from 99% down to 10% over the range. + // this maps to 396,000 - 400,000 allowed alias-driven decodes over the range. + // 400,000 decode operations is ~100MB of allocations in worst-case scenarios (single-item maps). + return 0.99 - 0.89*(float64(decodeCount-alias_ratio_range_low)/alias_ratio_range) + } +} + +func (d *decoder) unmarshal(n *Node, out reflect.Value) (good bool) { + d.decodeCount++ + if d.aliasDepth > 0 { + d.aliasCount++ + } + if d.aliasCount > 100 && d.decodeCount > 1000 && float64(d.aliasCount)/float64(d.decodeCount) > allowedAliasRatio(d.decodeCount) { + failf("document contains excessive aliasing") + } + if out.Type() == nodeType { + out.Set(reflect.ValueOf(n).Elem()) + return true + } + switch n.Kind { + case DocumentNode: + return d.document(n, out) + case AliasNode: + return d.alias(n, out) + } + out, unmarshaled, good := d.prepare(n, out) + if unmarshaled { + return good + } + switch n.Kind { + case ScalarNode: + good = d.scalar(n, out) + case MappingNode: + good = d.mapping(n, out) + case SequenceNode: + good = d.sequence(n, out) + case 0: + if n.IsZero() { + return d.null(out) + } + fallthrough + default: + failf("cannot decode node with unknown kind %d", n.Kind) + } + return good +} + +func (d *decoder) document(n *Node, out reflect.Value) (good bool) { + if len(n.Content) == 1 { + d.doc = n + d.unmarshal(n.Content[0], out) + return true + } + return false +} + +func (d *decoder) alias(n *Node, out reflect.Value) (good bool) { + if d.aliases[n] { + // TODO this could actually be allowed in some circumstances. 
+ failf("anchor '%s' value contains itself", n.Value) + } + d.aliases[n] = true + d.aliasDepth++ + good = d.unmarshal(n.Alias, out) + d.aliasDepth-- + delete(d.aliases, n) + return good +} + +var zeroValue reflect.Value + +func resetMap(out reflect.Value) { + for _, k := range out.MapKeys() { + out.SetMapIndex(k, zeroValue) + } +} + +func (d *decoder) null(out reflect.Value) bool { + if out.CanAddr() { + switch out.Kind() { + case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: + out.Set(reflect.Zero(out.Type())) + return true + } + } + return false +} + +func (d *decoder) scalar(n *Node, out reflect.Value) bool { + var tag string + var resolved interface{} + if n.indicatedString() { + tag = strTag + resolved = n.Value + } else { + tag, resolved = resolve(n.Tag, n.Value) + if tag == binaryTag { + data, err := base64.StdEncoding.DecodeString(resolved.(string)) + if err != nil { + failf("!!binary value contains invalid base64 data") + } + resolved = string(data) + } + } + if resolved == nil { + return d.null(out) + } + if resolvedv := reflect.ValueOf(resolved); out.Type() == resolvedv.Type() { + // We've resolved to exactly the type we want, so use that. + out.Set(resolvedv) + return true + } + // Perhaps we can use the value as a TextUnmarshaler to + // set its value. + if out.CanAddr() { + u, ok := out.Addr().Interface().(encoding.TextUnmarshaler) + if ok { + var text []byte + if tag == binaryTag { + text = []byte(resolved.(string)) + } else { + // We let any value be unmarshaled into TextUnmarshaler. + // That might be more lax than we'd like, but the + // TextUnmarshaler itself should bowl out any dubious values. 
+ text = []byte(n.Value) + } + err := u.UnmarshalText(text) + if err != nil { + fail(err) + } + return true + } + } + switch out.Kind() { + case reflect.String: + if tag == binaryTag { + out.SetString(resolved.(string)) + return true + } + out.SetString(n.Value) + return true + case reflect.Interface: + out.Set(reflect.ValueOf(resolved)) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + // This used to work in v2, but it's very unfriendly. + isDuration := out.Type() == durationType + + switch resolved := resolved.(type) { + case int: + if !isDuration && !out.OverflowInt(int64(resolved)) { + out.SetInt(int64(resolved)) + return true + } + case int64: + if !isDuration && !out.OverflowInt(resolved) { + out.SetInt(resolved) + return true + } + case uint64: + if !isDuration && resolved <= math.MaxInt64 && !out.OverflowInt(int64(resolved)) { + out.SetInt(int64(resolved)) + return true + } + case float64: + if !isDuration && resolved <= math.MaxInt64 && !out.OverflowInt(int64(resolved)) { + out.SetInt(int64(resolved)) + return true + } + case string: + if out.Type() == durationType { + d, err := time.ParseDuration(resolved) + if err == nil { + out.SetInt(int64(d)) + return true + } + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch resolved := resolved.(type) { + case int: + if resolved >= 0 && !out.OverflowUint(uint64(resolved)) { + out.SetUint(uint64(resolved)) + return true + } + case int64: + if resolved >= 0 && !out.OverflowUint(uint64(resolved)) { + out.SetUint(uint64(resolved)) + return true + } + case uint64: + if !out.OverflowUint(uint64(resolved)) { + out.SetUint(uint64(resolved)) + return true + } + case float64: + if resolved <= math.MaxUint64 && !out.OverflowUint(uint64(resolved)) { + out.SetUint(uint64(resolved)) + return true + } + } + case reflect.Bool: + switch resolved := resolved.(type) { + case bool: + out.SetBool(resolved) + return true + 
case string: + // This offers some compatibility with the 1.1 spec (https://yaml.org/type/bool.html). + // It only works if explicitly attempting to unmarshal into a typed bool value. + switch resolved { + case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON": + out.SetBool(true) + return true + case "n", "N", "no", "No", "NO", "off", "Off", "OFF": + out.SetBool(false) + return true + } + } + case reflect.Float32, reflect.Float64: + switch resolved := resolved.(type) { + case int: + out.SetFloat(float64(resolved)) + return true + case int64: + out.SetFloat(float64(resolved)) + return true + case uint64: + out.SetFloat(float64(resolved)) + return true + case float64: + out.SetFloat(resolved) + return true + } + case reflect.Struct: + if resolvedv := reflect.ValueOf(resolved); out.Type() == resolvedv.Type() { + out.Set(resolvedv) + return true + } + case reflect.Ptr: + panic("yaml internal error: please report the issue") + } + d.terror(n, tag, out) + return false +} + +func settableValueOf(i interface{}) reflect.Value { + v := reflect.ValueOf(i) + sv := reflect.New(v.Type()).Elem() + sv.Set(v) + return sv +} + +func (d *decoder) sequence(n *Node, out reflect.Value) (good bool) { + l := len(n.Content) + + var iface reflect.Value + switch out.Kind() { + case reflect.Slice: + out.Set(reflect.MakeSlice(out.Type(), l, l)) + case reflect.Array: + if l != out.Len() { + failf("invalid array: want %d elements but got %d", out.Len(), l) + } + case reflect.Interface: + // No type hints. Will have to use a generic sequence. 
+ iface = out + out = settableValueOf(make([]interface{}, l)) + default: + d.terror(n, seqTag, out) + return false + } + et := out.Type().Elem() + + j := 0 + for i := 0; i < l; i++ { + e := reflect.New(et).Elem() + if ok := d.unmarshal(n.Content[i], e); ok { + out.Index(j).Set(e) + j++ + } + } + if out.Kind() != reflect.Array { + out.Set(out.Slice(0, j)) + } + if iface.IsValid() { + iface.Set(out) + } + return true +} + +func (d *decoder) mapping(n *Node, out reflect.Value) (good bool) { + l := len(n.Content) + if d.uniqueKeys { + nerrs := len(d.terrors) + for i := 0; i < l; i += 2 { + ni := n.Content[i] + for j := i + 2; j < l; j += 2 { + nj := n.Content[j] + if ni.Kind == nj.Kind && ni.Value == nj.Value { + d.terrors = append(d.terrors, fmt.Sprintf("line %d: mapping key %#v already defined at line %d", nj.Line, nj.Value, ni.Line)) + } + } + } + if len(d.terrors) > nerrs { + return false + } + } + switch out.Kind() { + case reflect.Struct: + return d.mappingStruct(n, out) + case reflect.Map: + // okay + case reflect.Interface: + iface := out + if isStringMap(n) { + out = reflect.MakeMap(d.stringMapType) + } else { + out = reflect.MakeMap(d.generalMapType) + } + iface.Set(out) + default: + d.terror(n, mapTag, out) + return false + } + + outt := out.Type() + kt := outt.Key() + et := outt.Elem() + + stringMapType := d.stringMapType + generalMapType := d.generalMapType + if outt.Elem() == ifaceType { + if outt.Key().Kind() == reflect.String { + d.stringMapType = outt + } else if outt.Key() == ifaceType { + d.generalMapType = outt + } + } + + mergedFields := d.mergedFields + d.mergedFields = nil + + var mergeNode *Node + + mapIsNew := false + if out.IsNil() { + out.Set(reflect.MakeMap(outt)) + mapIsNew = true + } + for i := 0; i < l; i += 2 { + if isMerge(n.Content[i]) { + mergeNode = n.Content[i+1] + continue + } + k := reflect.New(kt).Elem() + if d.unmarshal(n.Content[i], k) { + if mergedFields != nil { + ki := k.Interface() + if mergedFields[ki] { + continue + } + 
mergedFields[ki] = true + } + kkind := k.Kind() + if kkind == reflect.Interface { + kkind = k.Elem().Kind() + } + if kkind == reflect.Map || kkind == reflect.Slice { + failf("invalid map key: %#v", k.Interface()) + } + e := reflect.New(et).Elem() + if d.unmarshal(n.Content[i+1], e) || n.Content[i+1].ShortTag() == nullTag && (mapIsNew || !out.MapIndex(k).IsValid()) { + out.SetMapIndex(k, e) + } + } + } + + d.mergedFields = mergedFields + if mergeNode != nil { + d.merge(n, mergeNode, out) + } + + d.stringMapType = stringMapType + d.generalMapType = generalMapType + return true +} + +func isStringMap(n *Node) bool { + if n.Kind != MappingNode { + return false + } + l := len(n.Content) + for i := 0; i < l; i += 2 { + shortTag := n.Content[i].ShortTag() + if shortTag != strTag && shortTag != mergeTag { + return false + } + } + return true +} + +func (d *decoder) mappingStruct(n *Node, out reflect.Value) (good bool) { + sinfo, err := getStructInfo(out.Type()) + if err != nil { + panic(err) + } + + var inlineMap reflect.Value + var elemType reflect.Type + if sinfo.InlineMap != -1 { + inlineMap = out.Field(sinfo.InlineMap) + elemType = inlineMap.Type().Elem() + } + + for _, index := range sinfo.InlineUnmarshalers { + field := d.fieldByIndex(n, out, index) + d.prepare(n, field) + } + + mergedFields := d.mergedFields + d.mergedFields = nil + var mergeNode *Node + var doneFields []bool + if d.uniqueKeys { + doneFields = make([]bool, len(sinfo.FieldsList)) + } + name := settableValueOf("") + l := len(n.Content) + for i := 0; i < l; i += 2 { + ni := n.Content[i] + if isMerge(ni) { + mergeNode = n.Content[i+1] + continue + } + if !d.unmarshal(ni, name) { + continue + } + sname := name.String() + if mergedFields != nil { + if mergedFields[sname] { + continue + } + mergedFields[sname] = true + } + if info, ok := sinfo.FieldsMap[sname]; ok { + if d.uniqueKeys { + if doneFields[info.Id] { + d.terrors = append(d.terrors, fmt.Sprintf("line %d: field %s already set in type %s", 
ni.Line, name.String(), out.Type())) + continue + } + doneFields[info.Id] = true + } + var field reflect.Value + if info.Inline == nil { + field = out.Field(info.Num) + } else { + field = d.fieldByIndex(n, out, info.Inline) + } + d.unmarshal(n.Content[i+1], field) + } else if sinfo.InlineMap != -1 { + if inlineMap.IsNil() { + inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + } + value := reflect.New(elemType).Elem() + d.unmarshal(n.Content[i+1], value) + inlineMap.SetMapIndex(name, value) + } else if d.knownFields { + d.terrors = append(d.terrors, fmt.Sprintf("line %d: field %s not found in type %s", ni.Line, name.String(), out.Type())) + } + } + + d.mergedFields = mergedFields + if mergeNode != nil { + d.merge(n, mergeNode, out) + } + return true +} + +func failWantMap() { + failf("map merge requires map or sequence of maps as the value") +} + +func (d *decoder) merge(parent *Node, merge *Node, out reflect.Value) { + mergedFields := d.mergedFields + if mergedFields == nil { + d.mergedFields = make(map[interface{}]bool) + for i := 0; i < len(parent.Content); i += 2 { + k := reflect.New(ifaceType).Elem() + if d.unmarshal(parent.Content[i], k) { + d.mergedFields[k.Interface()] = true + } + } + } + + switch merge.Kind { + case MappingNode: + d.unmarshal(merge, out) + case AliasNode: + if merge.Alias != nil && merge.Alias.Kind != MappingNode { + failWantMap() + } + d.unmarshal(merge, out) + case SequenceNode: + for i := 0; i < len(merge.Content); i++ { + ni := merge.Content[i] + if ni.Kind == AliasNode { + if ni.Alias != nil && ni.Alias.Kind != MappingNode { + failWantMap() + } + } else if ni.Kind != MappingNode { + failWantMap() + } + d.unmarshal(ni, out) + } + default: + failWantMap() + } + + d.mergedFields = mergedFields +} + +func isMerge(n *Node) bool { + return n.Kind == ScalarNode && n.Value == "<<" && (n.Tag == "" || n.Tag == "!" 
|| shortTag(n.Tag) == mergeTag) +} diff --git a/vendor/gopkg.in/yaml.v3/emitterc.go b/vendor/gopkg.in/yaml.v3/emitterc.go new file mode 100644 index 000000000..0f47c9ca8 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/emitterc.go @@ -0,0 +1,2020 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// Copyright (c) 2006-2010 Kirill Simonov +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package yaml + +import ( + "bytes" + "fmt" +) + +// Flush the buffer if needed. +func flush(emitter *yaml_emitter_t) bool { + if emitter.buffer_pos+5 >= len(emitter.buffer) { + return yaml_emitter_flush(emitter) + } + return true +} + +// Put a character to the output buffer. 
+func put(emitter *yaml_emitter_t, value byte) bool { + if emitter.buffer_pos+5 >= len(emitter.buffer) && !yaml_emitter_flush(emitter) { + return false + } + emitter.buffer[emitter.buffer_pos] = value + emitter.buffer_pos++ + emitter.column++ + return true +} + +// Put a line break to the output buffer. +func put_break(emitter *yaml_emitter_t) bool { + if emitter.buffer_pos+5 >= len(emitter.buffer) && !yaml_emitter_flush(emitter) { + return false + } + switch emitter.line_break { + case yaml_CR_BREAK: + emitter.buffer[emitter.buffer_pos] = '\r' + emitter.buffer_pos += 1 + case yaml_LN_BREAK: + emitter.buffer[emitter.buffer_pos] = '\n' + emitter.buffer_pos += 1 + case yaml_CRLN_BREAK: + emitter.buffer[emitter.buffer_pos+0] = '\r' + emitter.buffer[emitter.buffer_pos+1] = '\n' + emitter.buffer_pos += 2 + default: + panic("unknown line break setting") + } + if emitter.column == 0 { + emitter.space_above = true + } + emitter.column = 0 + emitter.line++ + // [Go] Do this here and below and drop from everywhere else (see commented lines). + emitter.indention = true + return true +} + +// Copy a character from a string into buffer. +func write(emitter *yaml_emitter_t, s []byte, i *int) bool { + if emitter.buffer_pos+5 >= len(emitter.buffer) && !yaml_emitter_flush(emitter) { + return false + } + p := emitter.buffer_pos + w := width(s[*i]) + switch w { + case 4: + emitter.buffer[p+3] = s[*i+3] + fallthrough + case 3: + emitter.buffer[p+2] = s[*i+2] + fallthrough + case 2: + emitter.buffer[p+1] = s[*i+1] + fallthrough + case 1: + emitter.buffer[p+0] = s[*i+0] + default: + panic("unknown character width") + } + emitter.column++ + emitter.buffer_pos += w + *i += w + return true +} + +// Write a whole string into buffer. +func write_all(emitter *yaml_emitter_t, s []byte) bool { + for i := 0; i < len(s); { + if !write(emitter, s, &i) { + return false + } + } + return true +} + +// Copy a line break character from a string into buffer. 
+func write_break(emitter *yaml_emitter_t, s []byte, i *int) bool { + if s[*i] == '\n' { + if !put_break(emitter) { + return false + } + *i++ + } else { + if !write(emitter, s, i) { + return false + } + if emitter.column == 0 { + emitter.space_above = true + } + emitter.column = 0 + emitter.line++ + // [Go] Do this here and above and drop from everywhere else (see commented lines). + emitter.indention = true + } + return true +} + +// Set an emitter error and return false. +func yaml_emitter_set_emitter_error(emitter *yaml_emitter_t, problem string) bool { + emitter.error = yaml_EMITTER_ERROR + emitter.problem = problem + return false +} + +// Emit an event. +func yaml_emitter_emit(emitter *yaml_emitter_t, event *yaml_event_t) bool { + emitter.events = append(emitter.events, *event) + for !yaml_emitter_need_more_events(emitter) { + event := &emitter.events[emitter.events_head] + if !yaml_emitter_analyze_event(emitter, event) { + return false + } + if !yaml_emitter_state_machine(emitter, event) { + return false + } + yaml_event_delete(event) + emitter.events_head++ + } + return true +} + +// Check if we need to accumulate more events before emitting. 
+// +// We accumulate extra +// - 1 event for DOCUMENT-START +// - 2 events for SEQUENCE-START +// - 3 events for MAPPING-START +// +func yaml_emitter_need_more_events(emitter *yaml_emitter_t) bool { + if emitter.events_head == len(emitter.events) { + return true + } + var accumulate int + switch emitter.events[emitter.events_head].typ { + case yaml_DOCUMENT_START_EVENT: + accumulate = 1 + break + case yaml_SEQUENCE_START_EVENT: + accumulate = 2 + break + case yaml_MAPPING_START_EVENT: + accumulate = 3 + break + default: + return false + } + if len(emitter.events)-emitter.events_head > accumulate { + return false + } + var level int + for i := emitter.events_head; i < len(emitter.events); i++ { + switch emitter.events[i].typ { + case yaml_STREAM_START_EVENT, yaml_DOCUMENT_START_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT: + level++ + case yaml_STREAM_END_EVENT, yaml_DOCUMENT_END_EVENT, yaml_SEQUENCE_END_EVENT, yaml_MAPPING_END_EVENT: + level-- + } + if level == 0 { + return false + } + } + return true +} + +// Append a directive to the directives stack. +func yaml_emitter_append_tag_directive(emitter *yaml_emitter_t, value *yaml_tag_directive_t, allow_duplicates bool) bool { + for i := 0; i < len(emitter.tag_directives); i++ { + if bytes.Equal(value.handle, emitter.tag_directives[i].handle) { + if allow_duplicates { + return true + } + return yaml_emitter_set_emitter_error(emitter, "duplicate %TAG directive") + } + } + + // [Go] Do we actually need to copy this given garbage collection + // and the lack of deallocating destructors? + tag_copy := yaml_tag_directive_t{ + handle: make([]byte, len(value.handle)), + prefix: make([]byte, len(value.prefix)), + } + copy(tag_copy.handle, value.handle) + copy(tag_copy.prefix, value.prefix) + emitter.tag_directives = append(emitter.tag_directives, tag_copy) + return true +} + +// Increase the indentation level. 
+func yaml_emitter_increase_indent(emitter *yaml_emitter_t, flow, indentless bool) bool { + emitter.indents = append(emitter.indents, emitter.indent) + if emitter.indent < 0 { + if flow { + emitter.indent = emitter.best_indent + } else { + emitter.indent = 0 + } + } else if !indentless { + // [Go] This was changed so that indentations are more regular. + if emitter.states[len(emitter.states)-1] == yaml_EMIT_BLOCK_SEQUENCE_ITEM_STATE { + // The first indent inside a sequence will just skip the "- " indicator. + emitter.indent += 2 + } else { + // Everything else aligns to the chosen indentation. + emitter.indent = emitter.best_indent*((emitter.indent+emitter.best_indent)/emitter.best_indent) + } + } + return true +} + +// State dispatcher. +func yaml_emitter_state_machine(emitter *yaml_emitter_t, event *yaml_event_t) bool { + switch emitter.state { + default: + case yaml_EMIT_STREAM_START_STATE: + return yaml_emitter_emit_stream_start(emitter, event) + + case yaml_EMIT_FIRST_DOCUMENT_START_STATE: + return yaml_emitter_emit_document_start(emitter, event, true) + + case yaml_EMIT_DOCUMENT_START_STATE: + return yaml_emitter_emit_document_start(emitter, event, false) + + case yaml_EMIT_DOCUMENT_CONTENT_STATE: + return yaml_emitter_emit_document_content(emitter, event) + + case yaml_EMIT_DOCUMENT_END_STATE: + return yaml_emitter_emit_document_end(emitter, event) + + case yaml_EMIT_FLOW_SEQUENCE_FIRST_ITEM_STATE: + return yaml_emitter_emit_flow_sequence_item(emitter, event, true, false) + + case yaml_EMIT_FLOW_SEQUENCE_TRAIL_ITEM_STATE: + return yaml_emitter_emit_flow_sequence_item(emitter, event, false, true) + + case yaml_EMIT_FLOW_SEQUENCE_ITEM_STATE: + return yaml_emitter_emit_flow_sequence_item(emitter, event, false, false) + + case yaml_EMIT_FLOW_MAPPING_FIRST_KEY_STATE: + return yaml_emitter_emit_flow_mapping_key(emitter, event, true, false) + + case yaml_EMIT_FLOW_MAPPING_TRAIL_KEY_STATE: + return yaml_emitter_emit_flow_mapping_key(emitter, event, false, true) + + 
case yaml_EMIT_FLOW_MAPPING_KEY_STATE: + return yaml_emitter_emit_flow_mapping_key(emitter, event, false, false) + + case yaml_EMIT_FLOW_MAPPING_SIMPLE_VALUE_STATE: + return yaml_emitter_emit_flow_mapping_value(emitter, event, true) + + case yaml_EMIT_FLOW_MAPPING_VALUE_STATE: + return yaml_emitter_emit_flow_mapping_value(emitter, event, false) + + case yaml_EMIT_BLOCK_SEQUENCE_FIRST_ITEM_STATE: + return yaml_emitter_emit_block_sequence_item(emitter, event, true) + + case yaml_EMIT_BLOCK_SEQUENCE_ITEM_STATE: + return yaml_emitter_emit_block_sequence_item(emitter, event, false) + + case yaml_EMIT_BLOCK_MAPPING_FIRST_KEY_STATE: + return yaml_emitter_emit_block_mapping_key(emitter, event, true) + + case yaml_EMIT_BLOCK_MAPPING_KEY_STATE: + return yaml_emitter_emit_block_mapping_key(emitter, event, false) + + case yaml_EMIT_BLOCK_MAPPING_SIMPLE_VALUE_STATE: + return yaml_emitter_emit_block_mapping_value(emitter, event, true) + + case yaml_EMIT_BLOCK_MAPPING_VALUE_STATE: + return yaml_emitter_emit_block_mapping_value(emitter, event, false) + + case yaml_EMIT_END_STATE: + return yaml_emitter_set_emitter_error(emitter, "expected nothing after STREAM-END") + } + panic("invalid emitter state") +} + +// Expect STREAM-START. 
+func yaml_emitter_emit_stream_start(emitter *yaml_emitter_t, event *yaml_event_t) bool { + if event.typ != yaml_STREAM_START_EVENT { + return yaml_emitter_set_emitter_error(emitter, "expected STREAM-START") + } + if emitter.encoding == yaml_ANY_ENCODING { + emitter.encoding = event.encoding + if emitter.encoding == yaml_ANY_ENCODING { + emitter.encoding = yaml_UTF8_ENCODING + } + } + if emitter.best_indent < 2 || emitter.best_indent > 9 { + emitter.best_indent = 2 + } + if emitter.best_width >= 0 && emitter.best_width <= emitter.best_indent*2 { + emitter.best_width = 80 + } + if emitter.best_width < 0 { + emitter.best_width = 1<<31 - 1 + } + if emitter.line_break == yaml_ANY_BREAK { + emitter.line_break = yaml_LN_BREAK + } + + emitter.indent = -1 + emitter.line = 0 + emitter.column = 0 + emitter.whitespace = true + emitter.indention = true + emitter.space_above = true + emitter.foot_indent = -1 + + if emitter.encoding != yaml_UTF8_ENCODING { + if !yaml_emitter_write_bom(emitter) { + return false + } + } + emitter.state = yaml_EMIT_FIRST_DOCUMENT_START_STATE + return true +} + +// Expect DOCUMENT-START or STREAM-END. 
+func yaml_emitter_emit_document_start(emitter *yaml_emitter_t, event *yaml_event_t, first bool) bool { + + if event.typ == yaml_DOCUMENT_START_EVENT { + + if event.version_directive != nil { + if !yaml_emitter_analyze_version_directive(emitter, event.version_directive) { + return false + } + } + + for i := 0; i < len(event.tag_directives); i++ { + tag_directive := &event.tag_directives[i] + if !yaml_emitter_analyze_tag_directive(emitter, tag_directive) { + return false + } + if !yaml_emitter_append_tag_directive(emitter, tag_directive, false) { + return false + } + } + + for i := 0; i < len(default_tag_directives); i++ { + tag_directive := &default_tag_directives[i] + if !yaml_emitter_append_tag_directive(emitter, tag_directive, true) { + return false + } + } + + implicit := event.implicit + if !first || emitter.canonical { + implicit = false + } + + if emitter.open_ended && (event.version_directive != nil || len(event.tag_directives) > 0) { + if !yaml_emitter_write_indicator(emitter, []byte("..."), true, false, false) { + return false + } + if !yaml_emitter_write_indent(emitter) { + return false + } + } + + if event.version_directive != nil { + implicit = false + if !yaml_emitter_write_indicator(emitter, []byte("%YAML"), true, false, false) { + return false + } + if !yaml_emitter_write_indicator(emitter, []byte("1.1"), true, false, false) { + return false + } + if !yaml_emitter_write_indent(emitter) { + return false + } + } + + if len(event.tag_directives) > 0 { + implicit = false + for i := 0; i < len(event.tag_directives); i++ { + tag_directive := &event.tag_directives[i] + if !yaml_emitter_write_indicator(emitter, []byte("%TAG"), true, false, false) { + return false + } + if !yaml_emitter_write_tag_handle(emitter, tag_directive.handle) { + return false + } + if !yaml_emitter_write_tag_content(emitter, tag_directive.prefix, true) { + return false + } + if !yaml_emitter_write_indent(emitter) { + return false + } + } + } + + if 
yaml_emitter_check_empty_document(emitter) { + implicit = false + } + if !implicit { + if !yaml_emitter_write_indent(emitter) { + return false + } + if !yaml_emitter_write_indicator(emitter, []byte("---"), true, false, false) { + return false + } + if emitter.canonical || true { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + } + + if len(emitter.head_comment) > 0 { + if !yaml_emitter_process_head_comment(emitter) { + return false + } + if !put_break(emitter) { + return false + } + } + + emitter.state = yaml_EMIT_DOCUMENT_CONTENT_STATE + return true + } + + if event.typ == yaml_STREAM_END_EVENT { + if emitter.open_ended { + if !yaml_emitter_write_indicator(emitter, []byte("..."), true, false, false) { + return false + } + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if !yaml_emitter_flush(emitter) { + return false + } + emitter.state = yaml_EMIT_END_STATE + return true + } + + return yaml_emitter_set_emitter_error(emitter, "expected DOCUMENT-START or STREAM-END") +} + +// Expect the root node. +func yaml_emitter_emit_document_content(emitter *yaml_emitter_t, event *yaml_event_t) bool { + emitter.states = append(emitter.states, yaml_EMIT_DOCUMENT_END_STATE) + + if !yaml_emitter_process_head_comment(emitter) { + return false + } + if !yaml_emitter_emit_node(emitter, event, true, false, false, false) { + return false + } + if !yaml_emitter_process_line_comment(emitter) { + return false + } + if !yaml_emitter_process_foot_comment(emitter) { + return false + } + return true +} + +// Expect DOCUMENT-END. +func yaml_emitter_emit_document_end(emitter *yaml_emitter_t, event *yaml_event_t) bool { + if event.typ != yaml_DOCUMENT_END_EVENT { + return yaml_emitter_set_emitter_error(emitter, "expected DOCUMENT-END") + } + // [Go] Force document foot separation. 
+ emitter.foot_indent = 0 + if !yaml_emitter_process_foot_comment(emitter) { + return false + } + emitter.foot_indent = -1 + if !yaml_emitter_write_indent(emitter) { + return false + } + if !event.implicit { + // [Go] Allocate the slice elsewhere. + if !yaml_emitter_write_indicator(emitter, []byte("..."), true, false, false) { + return false + } + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if !yaml_emitter_flush(emitter) { + return false + } + emitter.state = yaml_EMIT_DOCUMENT_START_STATE + emitter.tag_directives = emitter.tag_directives[:0] + return true +} + +// Expect a flow item node. +func yaml_emitter_emit_flow_sequence_item(emitter *yaml_emitter_t, event *yaml_event_t, first, trail bool) bool { + if first { + if !yaml_emitter_write_indicator(emitter, []byte{'['}, true, true, false) { + return false + } + if !yaml_emitter_increase_indent(emitter, true, false) { + return false + } + emitter.flow_level++ + } + + if event.typ == yaml_SEQUENCE_END_EVENT { + if emitter.canonical && !first && !trail { + if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { + return false + } + } + emitter.flow_level-- + emitter.indent = emitter.indents[len(emitter.indents)-1] + emitter.indents = emitter.indents[:len(emitter.indents)-1] + if emitter.column == 0 || emitter.canonical && !first { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if !yaml_emitter_write_indicator(emitter, []byte{']'}, false, false, false) { + return false + } + if !yaml_emitter_process_line_comment(emitter) { + return false + } + if !yaml_emitter_process_foot_comment(emitter) { + return false + } + emitter.state = emitter.states[len(emitter.states)-1] + emitter.states = emitter.states[:len(emitter.states)-1] + + return true + } + + if !first && !trail { + if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { + return false + } + } + + if !yaml_emitter_process_head_comment(emitter) { + return false + } + if 
emitter.column == 0 { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + + if emitter.canonical || emitter.column > emitter.best_width { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if len(emitter.line_comment)+len(emitter.foot_comment)+len(emitter.tail_comment) > 0 { + emitter.states = append(emitter.states, yaml_EMIT_FLOW_SEQUENCE_TRAIL_ITEM_STATE) + } else { + emitter.states = append(emitter.states, yaml_EMIT_FLOW_SEQUENCE_ITEM_STATE) + } + if !yaml_emitter_emit_node(emitter, event, false, true, false, false) { + return false + } + if len(emitter.line_comment)+len(emitter.foot_comment)+len(emitter.tail_comment) > 0 { + if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { + return false + } + } + if !yaml_emitter_process_line_comment(emitter) { + return false + } + if !yaml_emitter_process_foot_comment(emitter) { + return false + } + return true +} + +// Expect a flow key node. +func yaml_emitter_emit_flow_mapping_key(emitter *yaml_emitter_t, event *yaml_event_t, first, trail bool) bool { + if first { + if !yaml_emitter_write_indicator(emitter, []byte{'{'}, true, true, false) { + return false + } + if !yaml_emitter_increase_indent(emitter, true, false) { + return false + } + emitter.flow_level++ + } + + if event.typ == yaml_MAPPING_END_EVENT { + if (emitter.canonical || len(emitter.head_comment)+len(emitter.foot_comment)+len(emitter.tail_comment) > 0) && !first && !trail { + if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { + return false + } + } + if !yaml_emitter_process_head_comment(emitter) { + return false + } + emitter.flow_level-- + emitter.indent = emitter.indents[len(emitter.indents)-1] + emitter.indents = emitter.indents[:len(emitter.indents)-1] + if emitter.canonical && !first { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if !yaml_emitter_write_indicator(emitter, []byte{'}'}, false, false, false) { + return false + } + if 
!yaml_emitter_process_line_comment(emitter) { + return false + } + if !yaml_emitter_process_foot_comment(emitter) { + return false + } + emitter.state = emitter.states[len(emitter.states)-1] + emitter.states = emitter.states[:len(emitter.states)-1] + return true + } + + if !first && !trail { + if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { + return false + } + } + + if !yaml_emitter_process_head_comment(emitter) { + return false + } + + if emitter.column == 0 { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + + if emitter.canonical || emitter.column > emitter.best_width { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + + if !emitter.canonical && yaml_emitter_check_simple_key(emitter) { + emitter.states = append(emitter.states, yaml_EMIT_FLOW_MAPPING_SIMPLE_VALUE_STATE) + return yaml_emitter_emit_node(emitter, event, false, false, true, true) + } + if !yaml_emitter_write_indicator(emitter, []byte{'?'}, true, false, false) { + return false + } + emitter.states = append(emitter.states, yaml_EMIT_FLOW_MAPPING_VALUE_STATE) + return yaml_emitter_emit_node(emitter, event, false, false, true, false) +} + +// Expect a flow value node. 
+func yaml_emitter_emit_flow_mapping_value(emitter *yaml_emitter_t, event *yaml_event_t, simple bool) bool { + if simple { + if !yaml_emitter_write_indicator(emitter, []byte{':'}, false, false, false) { + return false + } + } else { + if emitter.canonical || emitter.column > emitter.best_width { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if !yaml_emitter_write_indicator(emitter, []byte{':'}, true, false, false) { + return false + } + } + if len(emitter.line_comment)+len(emitter.foot_comment)+len(emitter.tail_comment) > 0 { + emitter.states = append(emitter.states, yaml_EMIT_FLOW_MAPPING_TRAIL_KEY_STATE) + } else { + emitter.states = append(emitter.states, yaml_EMIT_FLOW_MAPPING_KEY_STATE) + } + if !yaml_emitter_emit_node(emitter, event, false, false, true, false) { + return false + } + if len(emitter.line_comment)+len(emitter.foot_comment)+len(emitter.tail_comment) > 0 { + if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { + return false + } + } + if !yaml_emitter_process_line_comment(emitter) { + return false + } + if !yaml_emitter_process_foot_comment(emitter) { + return false + } + return true +} + +// Expect a block item node. 
+func yaml_emitter_emit_block_sequence_item(emitter *yaml_emitter_t, event *yaml_event_t, first bool) bool { + if first { + if !yaml_emitter_increase_indent(emitter, false, false) { + return false + } + } + if event.typ == yaml_SEQUENCE_END_EVENT { + emitter.indent = emitter.indents[len(emitter.indents)-1] + emitter.indents = emitter.indents[:len(emitter.indents)-1] + emitter.state = emitter.states[len(emitter.states)-1] + emitter.states = emitter.states[:len(emitter.states)-1] + return true + } + if !yaml_emitter_process_head_comment(emitter) { + return false + } + if !yaml_emitter_write_indent(emitter) { + return false + } + if !yaml_emitter_write_indicator(emitter, []byte{'-'}, true, false, true) { + return false + } + emitter.states = append(emitter.states, yaml_EMIT_BLOCK_SEQUENCE_ITEM_STATE) + if !yaml_emitter_emit_node(emitter, event, false, true, false, false) { + return false + } + if !yaml_emitter_process_line_comment(emitter) { + return false + } + if !yaml_emitter_process_foot_comment(emitter) { + return false + } + return true +} + +// Expect a block key node. +func yaml_emitter_emit_block_mapping_key(emitter *yaml_emitter_t, event *yaml_event_t, first bool) bool { + if first { + if !yaml_emitter_increase_indent(emitter, false, false) { + return false + } + } + if !yaml_emitter_process_head_comment(emitter) { + return false + } + if event.typ == yaml_MAPPING_END_EVENT { + emitter.indent = emitter.indents[len(emitter.indents)-1] + emitter.indents = emitter.indents[:len(emitter.indents)-1] + emitter.state = emitter.states[len(emitter.states)-1] + emitter.states = emitter.states[:len(emitter.states)-1] + return true + } + if !yaml_emitter_write_indent(emitter) { + return false + } + if len(emitter.line_comment) > 0 { + // [Go] A line comment was provided for the key. That's unusual as the + // scanner associates line comments with the value. Either way, + // save the line comment and render it appropriately later. 
+ emitter.key_line_comment = emitter.line_comment + emitter.line_comment = nil + } + if yaml_emitter_check_simple_key(emitter) { + emitter.states = append(emitter.states, yaml_EMIT_BLOCK_MAPPING_SIMPLE_VALUE_STATE) + return yaml_emitter_emit_node(emitter, event, false, false, true, true) + } + if !yaml_emitter_write_indicator(emitter, []byte{'?'}, true, false, true) { + return false + } + emitter.states = append(emitter.states, yaml_EMIT_BLOCK_MAPPING_VALUE_STATE) + return yaml_emitter_emit_node(emitter, event, false, false, true, false) +} + +// Expect a block value node. +func yaml_emitter_emit_block_mapping_value(emitter *yaml_emitter_t, event *yaml_event_t, simple bool) bool { + if simple { + if !yaml_emitter_write_indicator(emitter, []byte{':'}, false, false, false) { + return false + } + } else { + if !yaml_emitter_write_indent(emitter) { + return false + } + if !yaml_emitter_write_indicator(emitter, []byte{':'}, true, false, true) { + return false + } + } + if len(emitter.key_line_comment) > 0 { + // [Go] Line comments are generally associated with the value, but when there's + // no value on the same line as a mapping key they end up attached to the + // key itself. + if event.typ == yaml_SCALAR_EVENT { + if len(emitter.line_comment) == 0 { + // A scalar is coming and it has no line comments by itself yet, + // so just let it handle the line comment as usual. If it has a + // line comment, we can't have both so the one from the key is lost. + emitter.line_comment = emitter.key_line_comment + emitter.key_line_comment = nil + } + } else if event.sequence_style() != yaml_FLOW_SEQUENCE_STYLE && (event.typ == yaml_MAPPING_START_EVENT || event.typ == yaml_SEQUENCE_START_EVENT) { + // An indented block follows, so write the comment right now. 
+ emitter.line_comment, emitter.key_line_comment = emitter.key_line_comment, emitter.line_comment + if !yaml_emitter_process_line_comment(emitter) { + return false + } + emitter.line_comment, emitter.key_line_comment = emitter.key_line_comment, emitter.line_comment + } + } + emitter.states = append(emitter.states, yaml_EMIT_BLOCK_MAPPING_KEY_STATE) + if !yaml_emitter_emit_node(emitter, event, false, false, true, false) { + return false + } + if !yaml_emitter_process_line_comment(emitter) { + return false + } + if !yaml_emitter_process_foot_comment(emitter) { + return false + } + return true +} + +func yaml_emitter_silent_nil_event(emitter *yaml_emitter_t, event *yaml_event_t) bool { + return event.typ == yaml_SCALAR_EVENT && event.implicit && !emitter.canonical && len(emitter.scalar_data.value) == 0 +} + +// Expect a node. +func yaml_emitter_emit_node(emitter *yaml_emitter_t, event *yaml_event_t, + root bool, sequence bool, mapping bool, simple_key bool) bool { + + emitter.root_context = root + emitter.sequence_context = sequence + emitter.mapping_context = mapping + emitter.simple_key_context = simple_key + + switch event.typ { + case yaml_ALIAS_EVENT: + return yaml_emitter_emit_alias(emitter, event) + case yaml_SCALAR_EVENT: + return yaml_emitter_emit_scalar(emitter, event) + case yaml_SEQUENCE_START_EVENT: + return yaml_emitter_emit_sequence_start(emitter, event) + case yaml_MAPPING_START_EVENT: + return yaml_emitter_emit_mapping_start(emitter, event) + default: + return yaml_emitter_set_emitter_error(emitter, + fmt.Sprintf("expected SCALAR, SEQUENCE-START, MAPPING-START, or ALIAS, but got %v", event.typ)) + } +} + +// Expect ALIAS. +func yaml_emitter_emit_alias(emitter *yaml_emitter_t, event *yaml_event_t) bool { + if !yaml_emitter_process_anchor(emitter) { + return false + } + emitter.state = emitter.states[len(emitter.states)-1] + emitter.states = emitter.states[:len(emitter.states)-1] + return true +} + +// Expect SCALAR. 
+func yaml_emitter_emit_scalar(emitter *yaml_emitter_t, event *yaml_event_t) bool { + if !yaml_emitter_select_scalar_style(emitter, event) { + return false + } + if !yaml_emitter_process_anchor(emitter) { + return false + } + if !yaml_emitter_process_tag(emitter) { + return false + } + if !yaml_emitter_increase_indent(emitter, true, false) { + return false + } + if !yaml_emitter_process_scalar(emitter) { + return false + } + emitter.indent = emitter.indents[len(emitter.indents)-1] + emitter.indents = emitter.indents[:len(emitter.indents)-1] + emitter.state = emitter.states[len(emitter.states)-1] + emitter.states = emitter.states[:len(emitter.states)-1] + return true +} + +// Expect SEQUENCE-START. +func yaml_emitter_emit_sequence_start(emitter *yaml_emitter_t, event *yaml_event_t) bool { + if !yaml_emitter_process_anchor(emitter) { + return false + } + if !yaml_emitter_process_tag(emitter) { + return false + } + if emitter.flow_level > 0 || emitter.canonical || event.sequence_style() == yaml_FLOW_SEQUENCE_STYLE || + yaml_emitter_check_empty_sequence(emitter) { + emitter.state = yaml_EMIT_FLOW_SEQUENCE_FIRST_ITEM_STATE + } else { + emitter.state = yaml_EMIT_BLOCK_SEQUENCE_FIRST_ITEM_STATE + } + return true +} + +// Expect MAPPING-START. +func yaml_emitter_emit_mapping_start(emitter *yaml_emitter_t, event *yaml_event_t) bool { + if !yaml_emitter_process_anchor(emitter) { + return false + } + if !yaml_emitter_process_tag(emitter) { + return false + } + if emitter.flow_level > 0 || emitter.canonical || event.mapping_style() == yaml_FLOW_MAPPING_STYLE || + yaml_emitter_check_empty_mapping(emitter) { + emitter.state = yaml_EMIT_FLOW_MAPPING_FIRST_KEY_STATE + } else { + emitter.state = yaml_EMIT_BLOCK_MAPPING_FIRST_KEY_STATE + } + return true +} + +// Check if the document content is an empty scalar. +func yaml_emitter_check_empty_document(emitter *yaml_emitter_t) bool { + return false // [Go] Huh? +} + +// Check if the next events represent an empty sequence. 
+func yaml_emitter_check_empty_sequence(emitter *yaml_emitter_t) bool { + if len(emitter.events)-emitter.events_head < 2 { + return false + } + return emitter.events[emitter.events_head].typ == yaml_SEQUENCE_START_EVENT && + emitter.events[emitter.events_head+1].typ == yaml_SEQUENCE_END_EVENT +} + +// Check if the next events represent an empty mapping. +func yaml_emitter_check_empty_mapping(emitter *yaml_emitter_t) bool { + if len(emitter.events)-emitter.events_head < 2 { + return false + } + return emitter.events[emitter.events_head].typ == yaml_MAPPING_START_EVENT && + emitter.events[emitter.events_head+1].typ == yaml_MAPPING_END_EVENT +} + +// Check if the next node can be expressed as a simple key. +func yaml_emitter_check_simple_key(emitter *yaml_emitter_t) bool { + length := 0 + switch emitter.events[emitter.events_head].typ { + case yaml_ALIAS_EVENT: + length += len(emitter.anchor_data.anchor) + case yaml_SCALAR_EVENT: + if emitter.scalar_data.multiline { + return false + } + length += len(emitter.anchor_data.anchor) + + len(emitter.tag_data.handle) + + len(emitter.tag_data.suffix) + + len(emitter.scalar_data.value) + case yaml_SEQUENCE_START_EVENT: + if !yaml_emitter_check_empty_sequence(emitter) { + return false + } + length += len(emitter.anchor_data.anchor) + + len(emitter.tag_data.handle) + + len(emitter.tag_data.suffix) + case yaml_MAPPING_START_EVENT: + if !yaml_emitter_check_empty_mapping(emitter) { + return false + } + length += len(emitter.anchor_data.anchor) + + len(emitter.tag_data.handle) + + len(emitter.tag_data.suffix) + default: + return false + } + return length <= 128 +} + +// Determine an acceptable scalar style. 
+func yaml_emitter_select_scalar_style(emitter *yaml_emitter_t, event *yaml_event_t) bool { + + no_tag := len(emitter.tag_data.handle) == 0 && len(emitter.tag_data.suffix) == 0 + if no_tag && !event.implicit && !event.quoted_implicit { + return yaml_emitter_set_emitter_error(emitter, "neither tag nor implicit flags are specified") + } + + style := event.scalar_style() + if style == yaml_ANY_SCALAR_STYLE { + style = yaml_PLAIN_SCALAR_STYLE + } + if emitter.canonical { + style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + } + if emitter.simple_key_context && emitter.scalar_data.multiline { + style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + } + + if style == yaml_PLAIN_SCALAR_STYLE { + if emitter.flow_level > 0 && !emitter.scalar_data.flow_plain_allowed || + emitter.flow_level == 0 && !emitter.scalar_data.block_plain_allowed { + style = yaml_SINGLE_QUOTED_SCALAR_STYLE + } + if len(emitter.scalar_data.value) == 0 && (emitter.flow_level > 0 || emitter.simple_key_context) { + style = yaml_SINGLE_QUOTED_SCALAR_STYLE + } + if no_tag && !event.implicit { + style = yaml_SINGLE_QUOTED_SCALAR_STYLE + } + } + if style == yaml_SINGLE_QUOTED_SCALAR_STYLE { + if !emitter.scalar_data.single_quoted_allowed { + style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + } + } + if style == yaml_LITERAL_SCALAR_STYLE || style == yaml_FOLDED_SCALAR_STYLE { + if !emitter.scalar_data.block_allowed || emitter.flow_level > 0 || emitter.simple_key_context { + style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + } + } + + if no_tag && !event.quoted_implicit && style != yaml_PLAIN_SCALAR_STYLE { + emitter.tag_data.handle = []byte{'!'} + } + emitter.scalar_data.style = style + return true +} + +// Write an anchor. 
+func yaml_emitter_process_anchor(emitter *yaml_emitter_t) bool { + if emitter.anchor_data.anchor == nil { + return true + } + c := []byte{'&'} + if emitter.anchor_data.alias { + c[0] = '*' + } + if !yaml_emitter_write_indicator(emitter, c, true, false, false) { + return false + } + return yaml_emitter_write_anchor(emitter, emitter.anchor_data.anchor) +} + +// Write a tag. +func yaml_emitter_process_tag(emitter *yaml_emitter_t) bool { + if len(emitter.tag_data.handle) == 0 && len(emitter.tag_data.suffix) == 0 { + return true + } + if len(emitter.tag_data.handle) > 0 { + if !yaml_emitter_write_tag_handle(emitter, emitter.tag_data.handle) { + return false + } + if len(emitter.tag_data.suffix) > 0 { + if !yaml_emitter_write_tag_content(emitter, emitter.tag_data.suffix, false) { + return false + } + } + } else { + // [Go] Allocate these slices elsewhere. + if !yaml_emitter_write_indicator(emitter, []byte("!<"), true, false, false) { + return false + } + if !yaml_emitter_write_tag_content(emitter, emitter.tag_data.suffix, false) { + return false + } + if !yaml_emitter_write_indicator(emitter, []byte{'>'}, false, false, false) { + return false + } + } + return true +} + +// Write a scalar. 
+func yaml_emitter_process_scalar(emitter *yaml_emitter_t) bool { + switch emitter.scalar_data.style { + case yaml_PLAIN_SCALAR_STYLE: + return yaml_emitter_write_plain_scalar(emitter, emitter.scalar_data.value, !emitter.simple_key_context) + + case yaml_SINGLE_QUOTED_SCALAR_STYLE: + return yaml_emitter_write_single_quoted_scalar(emitter, emitter.scalar_data.value, !emitter.simple_key_context) + + case yaml_DOUBLE_QUOTED_SCALAR_STYLE: + return yaml_emitter_write_double_quoted_scalar(emitter, emitter.scalar_data.value, !emitter.simple_key_context) + + case yaml_LITERAL_SCALAR_STYLE: + return yaml_emitter_write_literal_scalar(emitter, emitter.scalar_data.value) + + case yaml_FOLDED_SCALAR_STYLE: + return yaml_emitter_write_folded_scalar(emitter, emitter.scalar_data.value) + } + panic("unknown scalar style") +} + +// Write a head comment. +func yaml_emitter_process_head_comment(emitter *yaml_emitter_t) bool { + if len(emitter.tail_comment) > 0 { + if !yaml_emitter_write_indent(emitter) { + return false + } + if !yaml_emitter_write_comment(emitter, emitter.tail_comment) { + return false + } + emitter.tail_comment = emitter.tail_comment[:0] + emitter.foot_indent = emitter.indent + if emitter.foot_indent < 0 { + emitter.foot_indent = 0 + } + } + + if len(emitter.head_comment) == 0 { + return true + } + if !yaml_emitter_write_indent(emitter) { + return false + } + if !yaml_emitter_write_comment(emitter, emitter.head_comment) { + return false + } + emitter.head_comment = emitter.head_comment[:0] + return true +} + +// Write an line comment. +func yaml_emitter_process_line_comment(emitter *yaml_emitter_t) bool { + if len(emitter.line_comment) == 0 { + return true + } + if !emitter.whitespace { + if !put(emitter, ' ') { + return false + } + } + if !yaml_emitter_write_comment(emitter, emitter.line_comment) { + return false + } + emitter.line_comment = emitter.line_comment[:0] + return true +} + +// Write a foot comment. 
+func yaml_emitter_process_foot_comment(emitter *yaml_emitter_t) bool { + if len(emitter.foot_comment) == 0 { + return true + } + if !yaml_emitter_write_indent(emitter) { + return false + } + if !yaml_emitter_write_comment(emitter, emitter.foot_comment) { + return false + } + emitter.foot_comment = emitter.foot_comment[:0] + emitter.foot_indent = emitter.indent + if emitter.foot_indent < 0 { + emitter.foot_indent = 0 + } + return true +} + +// Check if a %YAML directive is valid. +func yaml_emitter_analyze_version_directive(emitter *yaml_emitter_t, version_directive *yaml_version_directive_t) bool { + if version_directive.major != 1 || version_directive.minor != 1 { + return yaml_emitter_set_emitter_error(emitter, "incompatible %YAML directive") + } + return true +} + +// Check if a %TAG directive is valid. +func yaml_emitter_analyze_tag_directive(emitter *yaml_emitter_t, tag_directive *yaml_tag_directive_t) bool { + handle := tag_directive.handle + prefix := tag_directive.prefix + if len(handle) == 0 { + return yaml_emitter_set_emitter_error(emitter, "tag handle must not be empty") + } + if handle[0] != '!' { + return yaml_emitter_set_emitter_error(emitter, "tag handle must start with '!'") + } + if handle[len(handle)-1] != '!' { + return yaml_emitter_set_emitter_error(emitter, "tag handle must end with '!'") + } + for i := 1; i < len(handle)-1; i += width(handle[i]) { + if !is_alpha(handle, i) { + return yaml_emitter_set_emitter_error(emitter, "tag handle must contain alphanumerical characters only") + } + } + if len(prefix) == 0 { + return yaml_emitter_set_emitter_error(emitter, "tag prefix must not be empty") + } + return true +} + +// Check if an anchor is valid. 
+func yaml_emitter_analyze_anchor(emitter *yaml_emitter_t, anchor []byte, alias bool) bool { + if len(anchor) == 0 { + problem := "anchor value must not be empty" + if alias { + problem = "alias value must not be empty" + } + return yaml_emitter_set_emitter_error(emitter, problem) + } + for i := 0; i < len(anchor); i += width(anchor[i]) { + if !is_alpha(anchor, i) { + problem := "anchor value must contain alphanumerical characters only" + if alias { + problem = "alias value must contain alphanumerical characters only" + } + return yaml_emitter_set_emitter_error(emitter, problem) + } + } + emitter.anchor_data.anchor = anchor + emitter.anchor_data.alias = alias + return true +} + +// Check if a tag is valid. +func yaml_emitter_analyze_tag(emitter *yaml_emitter_t, tag []byte) bool { + if len(tag) == 0 { + return yaml_emitter_set_emitter_error(emitter, "tag value must not be empty") + } + for i := 0; i < len(emitter.tag_directives); i++ { + tag_directive := &emitter.tag_directives[i] + if bytes.HasPrefix(tag, tag_directive.prefix) { + emitter.tag_data.handle = tag_directive.handle + emitter.tag_data.suffix = tag[len(tag_directive.prefix):] + return true + } + } + emitter.tag_data.suffix = tag + return true +} + +// Check if a scalar is valid. 
+func yaml_emitter_analyze_scalar(emitter *yaml_emitter_t, value []byte) bool { + var ( + block_indicators = false + flow_indicators = false + line_breaks = false + special_characters = false + tab_characters = false + + leading_space = false + leading_break = false + trailing_space = false + trailing_break = false + break_space = false + space_break = false + + preceded_by_whitespace = false + followed_by_whitespace = false + previous_space = false + previous_break = false + ) + + emitter.scalar_data.value = value + + if len(value) == 0 { + emitter.scalar_data.multiline = false + emitter.scalar_data.flow_plain_allowed = false + emitter.scalar_data.block_plain_allowed = true + emitter.scalar_data.single_quoted_allowed = true + emitter.scalar_data.block_allowed = false + return true + } + + if len(value) >= 3 && ((value[0] == '-' && value[1] == '-' && value[2] == '-') || (value[0] == '.' && value[1] == '.' && value[2] == '.')) { + block_indicators = true + flow_indicators = true + } + + preceded_by_whitespace = true + for i, w := 0, 0; i < len(value); i += w { + w = width(value[i]) + followed_by_whitespace = i+w >= len(value) || is_blank(value, i+w) + + if i == 0 { + switch value[i] { + case '#', ',', '[', ']', '{', '}', '&', '*', '!', '|', '>', '\'', '"', '%', '@', '`': + flow_indicators = true + block_indicators = true + case '?', ':': + flow_indicators = true + if followed_by_whitespace { + block_indicators = true + } + case '-': + if followed_by_whitespace { + flow_indicators = true + block_indicators = true + } + } + } else { + switch value[i] { + case ',', '?', '[', ']', '{', '}': + flow_indicators = true + case ':': + flow_indicators = true + if followed_by_whitespace { + block_indicators = true + } + case '#': + if preceded_by_whitespace { + flow_indicators = true + block_indicators = true + } + } + } + + if value[i] == '\t' { + tab_characters = true + } else if !is_printable(value, i) || !is_ascii(value, i) && !emitter.unicode { + special_characters = true 
+ } + if is_space(value, i) { + if i == 0 { + leading_space = true + } + if i+width(value[i]) == len(value) { + trailing_space = true + } + if previous_break { + break_space = true + } + previous_space = true + previous_break = false + } else if is_break(value, i) { + line_breaks = true + if i == 0 { + leading_break = true + } + if i+width(value[i]) == len(value) { + trailing_break = true + } + if previous_space { + space_break = true + } + previous_space = false + previous_break = true + } else { + previous_space = false + previous_break = false + } + + // [Go]: Why 'z'? Couldn't be the end of the string as that's the loop condition. + preceded_by_whitespace = is_blankz(value, i) + } + + emitter.scalar_data.multiline = line_breaks + emitter.scalar_data.flow_plain_allowed = true + emitter.scalar_data.block_plain_allowed = true + emitter.scalar_data.single_quoted_allowed = true + emitter.scalar_data.block_allowed = true + + if leading_space || leading_break || trailing_space || trailing_break { + emitter.scalar_data.flow_plain_allowed = false + emitter.scalar_data.block_plain_allowed = false + } + if trailing_space { + emitter.scalar_data.block_allowed = false + } + if break_space { + emitter.scalar_data.flow_plain_allowed = false + emitter.scalar_data.block_plain_allowed = false + emitter.scalar_data.single_quoted_allowed = false + } + if space_break || tab_characters || special_characters { + emitter.scalar_data.flow_plain_allowed = false + emitter.scalar_data.block_plain_allowed = false + emitter.scalar_data.single_quoted_allowed = false + } + if space_break || special_characters { + emitter.scalar_data.block_allowed = false + } + if line_breaks { + emitter.scalar_data.flow_plain_allowed = false + emitter.scalar_data.block_plain_allowed = false + } + if flow_indicators { + emitter.scalar_data.flow_plain_allowed = false + } + if block_indicators { + emitter.scalar_data.block_plain_allowed = false + } + return true +} + +// Check if the event data is valid. 
+func yaml_emitter_analyze_event(emitter *yaml_emitter_t, event *yaml_event_t) bool { + + emitter.anchor_data.anchor = nil + emitter.tag_data.handle = nil + emitter.tag_data.suffix = nil + emitter.scalar_data.value = nil + + if len(event.head_comment) > 0 { + emitter.head_comment = event.head_comment + } + if len(event.line_comment) > 0 { + emitter.line_comment = event.line_comment + } + if len(event.foot_comment) > 0 { + emitter.foot_comment = event.foot_comment + } + if len(event.tail_comment) > 0 { + emitter.tail_comment = event.tail_comment + } + + switch event.typ { + case yaml_ALIAS_EVENT: + if !yaml_emitter_analyze_anchor(emitter, event.anchor, true) { + return false + } + + case yaml_SCALAR_EVENT: + if len(event.anchor) > 0 { + if !yaml_emitter_analyze_anchor(emitter, event.anchor, false) { + return false + } + } + if len(event.tag) > 0 && (emitter.canonical || (!event.implicit && !event.quoted_implicit)) { + if !yaml_emitter_analyze_tag(emitter, event.tag) { + return false + } + } + if !yaml_emitter_analyze_scalar(emitter, event.value) { + return false + } + + case yaml_SEQUENCE_START_EVENT: + if len(event.anchor) > 0 { + if !yaml_emitter_analyze_anchor(emitter, event.anchor, false) { + return false + } + } + if len(event.tag) > 0 && (emitter.canonical || !event.implicit) { + if !yaml_emitter_analyze_tag(emitter, event.tag) { + return false + } + } + + case yaml_MAPPING_START_EVENT: + if len(event.anchor) > 0 { + if !yaml_emitter_analyze_anchor(emitter, event.anchor, false) { + return false + } + } + if len(event.tag) > 0 && (emitter.canonical || !event.implicit) { + if !yaml_emitter_analyze_tag(emitter, event.tag) { + return false + } + } + } + return true +} + +// Write the BOM character. 
+func yaml_emitter_write_bom(emitter *yaml_emitter_t) bool { + if !flush(emitter) { + return false + } + pos := emitter.buffer_pos + emitter.buffer[pos+0] = '\xEF' + emitter.buffer[pos+1] = '\xBB' + emitter.buffer[pos+2] = '\xBF' + emitter.buffer_pos += 3 + return true +} + +func yaml_emitter_write_indent(emitter *yaml_emitter_t) bool { + indent := emitter.indent + if indent < 0 { + indent = 0 + } + if !emitter.indention || emitter.column > indent || (emitter.column == indent && !emitter.whitespace) { + if !put_break(emitter) { + return false + } + } + if emitter.foot_indent == indent { + if !put_break(emitter) { + return false + } + } + for emitter.column < indent { + if !put(emitter, ' ') { + return false + } + } + emitter.whitespace = true + //emitter.indention = true + emitter.space_above = false + emitter.foot_indent = -1 + return true +} + +func yaml_emitter_write_indicator(emitter *yaml_emitter_t, indicator []byte, need_whitespace, is_whitespace, is_indention bool) bool { + if need_whitespace && !emitter.whitespace { + if !put(emitter, ' ') { + return false + } + } + if !write_all(emitter, indicator) { + return false + } + emitter.whitespace = is_whitespace + emitter.indention = (emitter.indention && is_indention) + emitter.open_ended = false + return true +} + +func yaml_emitter_write_anchor(emitter *yaml_emitter_t, value []byte) bool { + if !write_all(emitter, value) { + return false + } + emitter.whitespace = false + emitter.indention = false + return true +} + +func yaml_emitter_write_tag_handle(emitter *yaml_emitter_t, value []byte) bool { + if !emitter.whitespace { + if !put(emitter, ' ') { + return false + } + } + if !write_all(emitter, value) { + return false + } + emitter.whitespace = false + emitter.indention = false + return true +} + +func yaml_emitter_write_tag_content(emitter *yaml_emitter_t, value []byte, need_whitespace bool) bool { + if need_whitespace && !emitter.whitespace { + if !put(emitter, ' ') { + return false + } + } + for i := 0; i 
< len(value); { + var must_write bool + switch value[i] { + case ';', '/', '?', ':', '@', '&', '=', '+', '$', ',', '_', '.', '~', '*', '\'', '(', ')', '[', ']': + must_write = true + default: + must_write = is_alpha(value, i) + } + if must_write { + if !write(emitter, value, &i) { + return false + } + } else { + w := width(value[i]) + for k := 0; k < w; k++ { + octet := value[i] + i++ + if !put(emitter, '%') { + return false + } + + c := octet >> 4 + if c < 10 { + c += '0' + } else { + c += 'A' - 10 + } + if !put(emitter, c) { + return false + } + + c = octet & 0x0f + if c < 10 { + c += '0' + } else { + c += 'A' - 10 + } + if !put(emitter, c) { + return false + } + } + } + } + emitter.whitespace = false + emitter.indention = false + return true +} + +func yaml_emitter_write_plain_scalar(emitter *yaml_emitter_t, value []byte, allow_breaks bool) bool { + if len(value) > 0 && !emitter.whitespace { + if !put(emitter, ' ') { + return false + } + } + + spaces := false + breaks := false + for i := 0; i < len(value); { + if is_space(value, i) { + if allow_breaks && !spaces && emitter.column > emitter.best_width && !is_space(value, i+1) { + if !yaml_emitter_write_indent(emitter) { + return false + } + i += width(value[i]) + } else { + if !write(emitter, value, &i) { + return false + } + } + spaces = true + } else if is_break(value, i) { + if !breaks && value[i] == '\n' { + if !put_break(emitter) { + return false + } + } + if !write_break(emitter, value, &i) { + return false + } + //emitter.indention = true + breaks = true + } else { + if breaks { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if !write(emitter, value, &i) { + return false + } + emitter.indention = false + spaces = false + breaks = false + } + } + + if len(value) > 0 { + emitter.whitespace = false + } + emitter.indention = false + if emitter.root_context { + emitter.open_ended = true + } + + return true +} + +func yaml_emitter_write_single_quoted_scalar(emitter *yaml_emitter_t, value 
[]byte, allow_breaks bool) bool { + + if !yaml_emitter_write_indicator(emitter, []byte{'\''}, true, false, false) { + return false + } + + spaces := false + breaks := false + for i := 0; i < len(value); { + if is_space(value, i) { + if allow_breaks && !spaces && emitter.column > emitter.best_width && i > 0 && i < len(value)-1 && !is_space(value, i+1) { + if !yaml_emitter_write_indent(emitter) { + return false + } + i += width(value[i]) + } else { + if !write(emitter, value, &i) { + return false + } + } + spaces = true + } else if is_break(value, i) { + if !breaks && value[i] == '\n' { + if !put_break(emitter) { + return false + } + } + if !write_break(emitter, value, &i) { + return false + } + //emitter.indention = true + breaks = true + } else { + if breaks { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if value[i] == '\'' { + if !put(emitter, '\'') { + return false + } + } + if !write(emitter, value, &i) { + return false + } + emitter.indention = false + spaces = false + breaks = false + } + } + if !yaml_emitter_write_indicator(emitter, []byte{'\''}, false, false, false) { + return false + } + emitter.whitespace = false + emitter.indention = false + return true +} + +func yaml_emitter_write_double_quoted_scalar(emitter *yaml_emitter_t, value []byte, allow_breaks bool) bool { + spaces := false + if !yaml_emitter_write_indicator(emitter, []byte{'"'}, true, false, false) { + return false + } + + for i := 0; i < len(value); { + if !is_printable(value, i) || (!emitter.unicode && !is_ascii(value, i)) || + is_bom(value, i) || is_break(value, i) || + value[i] == '"' || value[i] == '\\' { + + octet := value[i] + + var w int + var v rune + switch { + case octet&0x80 == 0x00: + w, v = 1, rune(octet&0x7F) + case octet&0xE0 == 0xC0: + w, v = 2, rune(octet&0x1F) + case octet&0xF0 == 0xE0: + w, v = 3, rune(octet&0x0F) + case octet&0xF8 == 0xF0: + w, v = 4, rune(octet&0x07) + } + for k := 1; k < w; k++ { + octet = value[i+k] + v = (v << 6) + (rune(octet) & 
0x3F) + } + i += w + + if !put(emitter, '\\') { + return false + } + + var ok bool + switch v { + case 0x00: + ok = put(emitter, '0') + case 0x07: + ok = put(emitter, 'a') + case 0x08: + ok = put(emitter, 'b') + case 0x09: + ok = put(emitter, 't') + case 0x0A: + ok = put(emitter, 'n') + case 0x0b: + ok = put(emitter, 'v') + case 0x0c: + ok = put(emitter, 'f') + case 0x0d: + ok = put(emitter, 'r') + case 0x1b: + ok = put(emitter, 'e') + case 0x22: + ok = put(emitter, '"') + case 0x5c: + ok = put(emitter, '\\') + case 0x85: + ok = put(emitter, 'N') + case 0xA0: + ok = put(emitter, '_') + case 0x2028: + ok = put(emitter, 'L') + case 0x2029: + ok = put(emitter, 'P') + default: + if v <= 0xFF { + ok = put(emitter, 'x') + w = 2 + } else if v <= 0xFFFF { + ok = put(emitter, 'u') + w = 4 + } else { + ok = put(emitter, 'U') + w = 8 + } + for k := (w - 1) * 4; ok && k >= 0; k -= 4 { + digit := byte((v >> uint(k)) & 0x0F) + if digit < 10 { + ok = put(emitter, digit+'0') + } else { + ok = put(emitter, digit+'A'-10) + } + } + } + if !ok { + return false + } + spaces = false + } else if is_space(value, i) { + if allow_breaks && !spaces && emitter.column > emitter.best_width && i > 0 && i < len(value)-1 { + if !yaml_emitter_write_indent(emitter) { + return false + } + if is_space(value, i+1) { + if !put(emitter, '\\') { + return false + } + } + i += width(value[i]) + } else if !write(emitter, value, &i) { + return false + } + spaces = true + } else { + if !write(emitter, value, &i) { + return false + } + spaces = false + } + } + if !yaml_emitter_write_indicator(emitter, []byte{'"'}, false, false, false) { + return false + } + emitter.whitespace = false + emitter.indention = false + return true +} + +func yaml_emitter_write_block_scalar_hints(emitter *yaml_emitter_t, value []byte) bool { + if is_space(value, 0) || is_break(value, 0) { + indent_hint := []byte{'0' + byte(emitter.best_indent)} + if !yaml_emitter_write_indicator(emitter, indent_hint, false, false, false) { + return 
false + } + } + + emitter.open_ended = false + + var chomp_hint [1]byte + if len(value) == 0 { + chomp_hint[0] = '-' + } else { + i := len(value) - 1 + for value[i]&0xC0 == 0x80 { + i-- + } + if !is_break(value, i) { + chomp_hint[0] = '-' + } else if i == 0 { + chomp_hint[0] = '+' + emitter.open_ended = true + } else { + i-- + for value[i]&0xC0 == 0x80 { + i-- + } + if is_break(value, i) { + chomp_hint[0] = '+' + emitter.open_ended = true + } + } + } + if chomp_hint[0] != 0 { + if !yaml_emitter_write_indicator(emitter, chomp_hint[:], false, false, false) { + return false + } + } + return true +} + +func yaml_emitter_write_literal_scalar(emitter *yaml_emitter_t, value []byte) bool { + if !yaml_emitter_write_indicator(emitter, []byte{'|'}, true, false, false) { + return false + } + if !yaml_emitter_write_block_scalar_hints(emitter, value) { + return false + } + if !yaml_emitter_process_line_comment(emitter) { + return false + } + //emitter.indention = true + emitter.whitespace = true + breaks := true + for i := 0; i < len(value); { + if is_break(value, i) { + if !write_break(emitter, value, &i) { + return false + } + //emitter.indention = true + breaks = true + } else { + if breaks { + if !yaml_emitter_write_indent(emitter) { + return false + } + } + if !write(emitter, value, &i) { + return false + } + emitter.indention = false + breaks = false + } + } + + return true +} + +func yaml_emitter_write_folded_scalar(emitter *yaml_emitter_t, value []byte) bool { + if !yaml_emitter_write_indicator(emitter, []byte{'>'}, true, false, false) { + return false + } + if !yaml_emitter_write_block_scalar_hints(emitter, value) { + return false + } + if !yaml_emitter_process_line_comment(emitter) { + return false + } + + //emitter.indention = true + emitter.whitespace = true + + breaks := true + leading_spaces := true + for i := 0; i < len(value); { + if is_break(value, i) { + if !breaks && !leading_spaces && value[i] == '\n' { + k := 0 + for is_break(value, k) { + k += 
width(value[k]) + } + if !is_blankz(value, k) { + if !put_break(emitter) { + return false + } + } + } + if !write_break(emitter, value, &i) { + return false + } + //emitter.indention = true + breaks = true + } else { + if breaks { + if !yaml_emitter_write_indent(emitter) { + return false + } + leading_spaces = is_blank(value, i) + } + if !breaks && is_space(value, i) && !is_space(value, i+1) && emitter.column > emitter.best_width { + if !yaml_emitter_write_indent(emitter) { + return false + } + i += width(value[i]) + } else { + if !write(emitter, value, &i) { + return false + } + } + emitter.indention = false + breaks = false + } + } + return true +} + +func yaml_emitter_write_comment(emitter *yaml_emitter_t, comment []byte) bool { + breaks := false + pound := false + for i := 0; i < len(comment); { + if is_break(comment, i) { + if !write_break(emitter, comment, &i) { + return false + } + //emitter.indention = true + breaks = true + pound = false + } else { + if breaks && !yaml_emitter_write_indent(emitter) { + return false + } + if !pound { + if comment[i] != '#' && (!put(emitter, '#') || !put(emitter, ' ')) { + return false + } + pound = true + } + if !write(emitter, comment, &i) { + return false + } + emitter.indention = false + breaks = false + } + } + if !breaks && !put_break(emitter) { + return false + } + + emitter.whitespace = true + //emitter.indention = true + return true +} diff --git a/vendor/gopkg.in/yaml.v3/encode.go b/vendor/gopkg.in/yaml.v3/encode.go new file mode 100644 index 000000000..de9e72a3e --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/encode.go @@ -0,0 +1,577 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "encoding" + "fmt" + "io" + "reflect" + "regexp" + "sort" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +type encoder struct { + emitter yaml_emitter_t + event yaml_event_t + out []byte + flow bool + indent int + doneInit bool +} + +func newEncoder() *encoder { + e := &encoder{} + yaml_emitter_initialize(&e.emitter) + yaml_emitter_set_output_string(&e.emitter, &e.out) + yaml_emitter_set_unicode(&e.emitter, true) + return e +} + +func newEncoderWithWriter(w io.Writer) *encoder { + e := &encoder{} + yaml_emitter_initialize(&e.emitter) + yaml_emitter_set_output_writer(&e.emitter, w) + yaml_emitter_set_unicode(&e.emitter, true) + return e +} + +func (e *encoder) init() { + if e.doneInit { + return + } + if e.indent == 0 { + e.indent = 4 + } + e.emitter.best_indent = e.indent + yaml_stream_start_event_initialize(&e.event, yaml_UTF8_ENCODING) + e.emit() + e.doneInit = true +} + +func (e *encoder) finish() { + e.emitter.open_ended = false + yaml_stream_end_event_initialize(&e.event) + e.emit() +} + +func (e *encoder) destroy() { + yaml_emitter_delete(&e.emitter) +} + +func (e *encoder) emit() { + // This will internally delete the e.event value. 
+ e.must(yaml_emitter_emit(&e.emitter, &e.event)) +} + +func (e *encoder) must(ok bool) { + if !ok { + msg := e.emitter.problem + if msg == "" { + msg = "unknown problem generating YAML content" + } + failf("%s", msg) + } +} + +func (e *encoder) marshalDoc(tag string, in reflect.Value) { + e.init() + var node *Node + if in.IsValid() { + node, _ = in.Interface().(*Node) + } + if node != nil && node.Kind == DocumentNode { + e.nodev(in) + } else { + yaml_document_start_event_initialize(&e.event, nil, nil, true) + e.emit() + e.marshal(tag, in) + yaml_document_end_event_initialize(&e.event, true) + e.emit() + } +} + +func (e *encoder) marshal(tag string, in reflect.Value) { + tag = shortTag(tag) + if !in.IsValid() || in.Kind() == reflect.Ptr && in.IsNil() { + e.nilv() + return + } + iface := in.Interface() + switch value := iface.(type) { + case *Node: + e.nodev(in) + return + case Node: + if !in.CanAddr() { + var n = reflect.New(in.Type()).Elem() + n.Set(in) + in = n + } + e.nodev(in.Addr()) + return + case time.Time: + e.timev(tag, in) + return + case *time.Time: + e.timev(tag, in.Elem()) + return + case time.Duration: + e.stringv(tag, reflect.ValueOf(value.String())) + return + case Marshaler: + v, err := value.MarshalYAML() + if err != nil { + fail(err) + } + if v == nil { + e.nilv() + return + } + e.marshal(tag, reflect.ValueOf(v)) + return + case encoding.TextMarshaler: + text, err := value.MarshalText() + if err != nil { + fail(err) + } + in = reflect.ValueOf(string(text)) + case nil: + e.nilv() + return + } + switch in.Kind() { + case reflect.Interface: + e.marshal(tag, in.Elem()) + case reflect.Map: + e.mapv(tag, in) + case reflect.Ptr: + e.marshal(tag, in.Elem()) + case reflect.Struct: + e.structv(tag, in) + case reflect.Slice, reflect.Array: + e.slicev(tag, in) + case reflect.String: + e.stringv(tag, in) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + e.intv(tag, in) + case reflect.Uint, reflect.Uint8, reflect.Uint16, 
reflect.Uint32, reflect.Uint64, reflect.Uintptr: + e.uintv(tag, in) + case reflect.Float32, reflect.Float64: + e.floatv(tag, in) + case reflect.Bool: + e.boolv(tag, in) + default: + panic("cannot marshal type: " + in.Type().String()) + } +} + +func (e *encoder) mapv(tag string, in reflect.Value) { + e.mappingv(tag, func() { + keys := keyList(in.MapKeys()) + sort.Sort(keys) + for _, k := range keys { + e.marshal("", k) + e.marshal("", in.MapIndex(k)) + } + }) +} + +func (e *encoder) fieldByIndex(v reflect.Value, index []int) (field reflect.Value) { + for _, num := range index { + for { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + continue + } + break + } + v = v.Field(num) + } + return v +} + +func (e *encoder) structv(tag string, in reflect.Value) { + sinfo, err := getStructInfo(in.Type()) + if err != nil { + panic(err) + } + e.mappingv(tag, func() { + for _, info := range sinfo.FieldsList { + var value reflect.Value + if info.Inline == nil { + value = in.Field(info.Num) + } else { + value = e.fieldByIndex(in, info.Inline) + if !value.IsValid() { + continue + } + } + if info.OmitEmpty && isZero(value) { + continue + } + e.marshal("", reflect.ValueOf(info.Key)) + e.flow = info.Flow + e.marshal("", value) + } + if sinfo.InlineMap >= 0 { + m := in.Field(sinfo.InlineMap) + if m.Len() > 0 { + e.flow = false + keys := keyList(m.MapKeys()) + sort.Sort(keys) + for _, k := range keys { + if _, found := sinfo.FieldsMap[k.String()]; found { + panic(fmt.Sprintf("cannot have key %q in inlined map: conflicts with struct field", k.String())) + } + e.marshal("", k) + e.flow = false + e.marshal("", m.MapIndex(k)) + } + } + } + }) +} + +func (e *encoder) mappingv(tag string, f func()) { + implicit := tag == "" + style := yaml_BLOCK_MAPPING_STYLE + if e.flow { + e.flow = false + style = yaml_FLOW_MAPPING_STYLE + } + yaml_mapping_start_event_initialize(&e.event, nil, []byte(tag), implicit, style) + e.emit() + f() + 
yaml_mapping_end_event_initialize(&e.event) + e.emit() +} + +func (e *encoder) slicev(tag string, in reflect.Value) { + implicit := tag == "" + style := yaml_BLOCK_SEQUENCE_STYLE + if e.flow { + e.flow = false + style = yaml_FLOW_SEQUENCE_STYLE + } + e.must(yaml_sequence_start_event_initialize(&e.event, nil, []byte(tag), implicit, style)) + e.emit() + n := in.Len() + for i := 0; i < n; i++ { + e.marshal("", in.Index(i)) + } + e.must(yaml_sequence_end_event_initialize(&e.event)) + e.emit() +} + +// isBase60 returns whether s is in base 60 notation as defined in YAML 1.1. +// +// The base 60 float notation in YAML 1.1 is a terrible idea and is unsupported +// in YAML 1.2 and by this package, but these should be marshalled quoted for +// the time being for compatibility with other parsers. +func isBase60Float(s string) (result bool) { + // Fast path. + if s == "" { + return false + } + c := s[0] + if !(c == '+' || c == '-' || c >= '0' && c <= '9') || strings.IndexByte(s, ':') < 0 { + return false + } + // Do the full match. + return base60float.MatchString(s) +} + +// From http://yaml.org/type/float.html, except the regular expression there +// is bogus. In practice parsers do not enforce the "\.[0-9_]*" suffix. +var base60float = regexp.MustCompile(`^[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+(?:\.[0-9_]*)?$`) + +// isOldBool returns whether s is bool notation as defined in YAML 1.1. +// +// We continue to force strings that YAML 1.1 would interpret as booleans to be +// rendered as quotes strings so that the marshalled output valid for YAML 1.1 +// parsing. 
+func isOldBool(s string) (result bool) { + switch s { + case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON", + "n", "N", "no", "No", "NO", "off", "Off", "OFF": + return true + default: + return false + } +} + +func (e *encoder) stringv(tag string, in reflect.Value) { + var style yaml_scalar_style_t + s := in.String() + canUsePlain := true + switch { + case !utf8.ValidString(s): + if tag == binaryTag { + failf("explicitly tagged !!binary data must be base64-encoded") + } + if tag != "" { + failf("cannot marshal invalid UTF-8 data as %s", shortTag(tag)) + } + // It can't be encoded directly as YAML so use a binary tag + // and encode it as base64. + tag = binaryTag + s = encodeBase64(s) + case tag == "": + // Check to see if it would resolve to a specific + // tag when encoded unquoted. If it doesn't, + // there's no need to quote it. + rtag, _ := resolve("", s) + canUsePlain = rtag == strTag && !(isBase60Float(s) || isOldBool(s)) + } + // Note: it's possible for user code to emit invalid YAML + // if they explicitly specify a tag and a string containing + // text that's incompatible with that tag. 
+ switch { + case strings.Contains(s, "\n"): + if e.flow { + style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + } else { + style = yaml_LITERAL_SCALAR_STYLE + } + case canUsePlain: + style = yaml_PLAIN_SCALAR_STYLE + default: + style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + } + e.emitScalar(s, "", tag, style, nil, nil, nil, nil) +} + +func (e *encoder) boolv(tag string, in reflect.Value) { + var s string + if in.Bool() { + s = "true" + } else { + s = "false" + } + e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil) +} + +func (e *encoder) intv(tag string, in reflect.Value) { + s := strconv.FormatInt(in.Int(), 10) + e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil) +} + +func (e *encoder) uintv(tag string, in reflect.Value) { + s := strconv.FormatUint(in.Uint(), 10) + e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil) +} + +func (e *encoder) timev(tag string, in reflect.Value) { + t := in.Interface().(time.Time) + s := t.Format(time.RFC3339Nano) + e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil) +} + +func (e *encoder) floatv(tag string, in reflect.Value) { + // Issue #352: When formatting, use the precision of the underlying value + precision := 64 + if in.Kind() == reflect.Float32 { + precision = 32 + } + + s := strconv.FormatFloat(in.Float(), 'g', -1, precision) + switch s { + case "+Inf": + s = ".inf" + case "-Inf": + s = "-.inf" + case "NaN": + s = ".nan" + } + e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil) +} + +func (e *encoder) nilv() { + e.emitScalar("null", "", "", yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil) +} + +func (e *encoder) emitScalar(value, anchor, tag string, style yaml_scalar_style_t, head, line, foot, tail []byte) { + // TODO Kill this function. Replace all initialize calls by their underlining Go literals. 
+ implicit := tag == "" + if !implicit { + tag = longTag(tag) + } + e.must(yaml_scalar_event_initialize(&e.event, []byte(anchor), []byte(tag), []byte(value), implicit, implicit, style)) + e.event.head_comment = head + e.event.line_comment = line + e.event.foot_comment = foot + e.event.tail_comment = tail + e.emit() +} + +func (e *encoder) nodev(in reflect.Value) { + e.node(in.Interface().(*Node), "") +} + +func (e *encoder) node(node *Node, tail string) { + // Zero nodes behave as nil. + if node.Kind == 0 && node.IsZero() { + e.nilv() + return + } + + // If the tag was not explicitly requested, and dropping it won't change the + // implicit tag of the value, don't include it in the presentation. + var tag = node.Tag + var stag = shortTag(tag) + var forceQuoting bool + if tag != "" && node.Style&TaggedStyle == 0 { + if node.Kind == ScalarNode { + if stag == strTag && node.Style&(SingleQuotedStyle|DoubleQuotedStyle|LiteralStyle|FoldedStyle) != 0 { + tag = "" + } else { + rtag, _ := resolve("", node.Value) + if rtag == stag { + tag = "" + } else if stag == strTag { + tag = "" + forceQuoting = true + } + } + } else { + var rtag string + switch node.Kind { + case MappingNode: + rtag = mapTag + case SequenceNode: + rtag = seqTag + } + if rtag == stag { + tag = "" + } + } + } + + switch node.Kind { + case DocumentNode: + yaml_document_start_event_initialize(&e.event, nil, nil, true) + e.event.head_comment = []byte(node.HeadComment) + e.emit() + for _, node := range node.Content { + e.node(node, "") + } + yaml_document_end_event_initialize(&e.event, true) + e.event.foot_comment = []byte(node.FootComment) + e.emit() + + case SequenceNode: + style := yaml_BLOCK_SEQUENCE_STYLE + if node.Style&FlowStyle != 0 { + style = yaml_FLOW_SEQUENCE_STYLE + } + e.must(yaml_sequence_start_event_initialize(&e.event, []byte(node.Anchor), []byte(longTag(tag)), tag == "", style)) + e.event.head_comment = []byte(node.HeadComment) + e.emit() + for _, node := range node.Content { + e.node(node, 
"") + } + e.must(yaml_sequence_end_event_initialize(&e.event)) + e.event.line_comment = []byte(node.LineComment) + e.event.foot_comment = []byte(node.FootComment) + e.emit() + + case MappingNode: + style := yaml_BLOCK_MAPPING_STYLE + if node.Style&FlowStyle != 0 { + style = yaml_FLOW_MAPPING_STYLE + } + yaml_mapping_start_event_initialize(&e.event, []byte(node.Anchor), []byte(longTag(tag)), tag == "", style) + e.event.tail_comment = []byte(tail) + e.event.head_comment = []byte(node.HeadComment) + e.emit() + + // The tail logic below moves the foot comment of prior keys to the following key, + // since the value for each key may be a nested structure and the foot needs to be + // processed only the entirety of the value is streamed. The last tail is processed + // with the mapping end event. + var tail string + for i := 0; i+1 < len(node.Content); i += 2 { + k := node.Content[i] + foot := k.FootComment + if foot != "" { + kopy := *k + kopy.FootComment = "" + k = &kopy + } + e.node(k, tail) + tail = foot + + v := node.Content[i+1] + e.node(v, "") + } + + yaml_mapping_end_event_initialize(&e.event) + e.event.tail_comment = []byte(tail) + e.event.line_comment = []byte(node.LineComment) + e.event.foot_comment = []byte(node.FootComment) + e.emit() + + case AliasNode: + yaml_alias_event_initialize(&e.event, []byte(node.Value)) + e.event.head_comment = []byte(node.HeadComment) + e.event.line_comment = []byte(node.LineComment) + e.event.foot_comment = []byte(node.FootComment) + e.emit() + + case ScalarNode: + value := node.Value + if !utf8.ValidString(value) { + if stag == binaryTag { + failf("explicitly tagged !!binary data must be base64-encoded") + } + if stag != "" { + failf("cannot marshal invalid UTF-8 data as %s", stag) + } + // It can't be encoded directly as YAML so use a binary tag + // and encode it as base64. 
+ tag = binaryTag + value = encodeBase64(value) + } + + style := yaml_PLAIN_SCALAR_STYLE + switch { + case node.Style&DoubleQuotedStyle != 0: + style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + case node.Style&SingleQuotedStyle != 0: + style = yaml_SINGLE_QUOTED_SCALAR_STYLE + case node.Style&LiteralStyle != 0: + style = yaml_LITERAL_SCALAR_STYLE + case node.Style&FoldedStyle != 0: + style = yaml_FOLDED_SCALAR_STYLE + case strings.Contains(value, "\n"): + style = yaml_LITERAL_SCALAR_STYLE + case forceQuoting: + style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + } + + e.emitScalar(value, node.Anchor, tag, style, []byte(node.HeadComment), []byte(node.LineComment), []byte(node.FootComment), []byte(tail)) + default: + failf("cannot encode node with unknown kind %d", node.Kind) + } +} diff --git a/vendor/gopkg.in/yaml.v3/parserc.go b/vendor/gopkg.in/yaml.v3/parserc.go new file mode 100644 index 000000000..268558a0d --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/parserc.go @@ -0,0 +1,1258 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// Copyright (c) 2006-2010 Kirill Simonov +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package yaml + +import ( + "bytes" +) + +// The parser implements the following grammar: +// +// stream ::= STREAM-START implicit_document? explicit_document* STREAM-END +// implicit_document ::= block_node DOCUMENT-END* +// explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* +// block_node_or_indentless_sequence ::= +// ALIAS +// | properties (block_content | indentless_block_sequence)? +// | block_content +// | indentless_block_sequence +// block_node ::= ALIAS +// | properties block_content? +// | block_content +// flow_node ::= ALIAS +// | properties flow_content? +// | flow_content +// properties ::= TAG ANCHOR? | ANCHOR TAG? +// block_content ::= block_collection | flow_collection | SCALAR +// flow_content ::= flow_collection | SCALAR +// block_collection ::= block_sequence | block_mapping +// flow_collection ::= flow_sequence | flow_mapping +// block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END +// indentless_sequence ::= (BLOCK-ENTRY block_node?)+ +// block_mapping ::= BLOCK-MAPPING_START +// ((KEY block_node_or_indentless_sequence?)? +// (VALUE block_node_or_indentless_sequence?)?)* +// BLOCK-END +// flow_sequence ::= FLOW-SEQUENCE-START +// (flow_sequence_entry FLOW-ENTRY)* +// flow_sequence_entry? +// FLOW-SEQUENCE-END +// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? +// flow_mapping ::= FLOW-MAPPING-START +// (flow_mapping_entry FLOW-ENTRY)* +// flow_mapping_entry? +// FLOW-MAPPING-END +// flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? + +// Peek the next token in the token queue. 
+func peek_token(parser *yaml_parser_t) *yaml_token_t { + if parser.token_available || yaml_parser_fetch_more_tokens(parser) { + token := &parser.tokens[parser.tokens_head] + yaml_parser_unfold_comments(parser, token) + return token + } + return nil +} + +// yaml_parser_unfold_comments walks through the comments queue and joins all +// comments behind the position of the provided token into the respective +// top-level comment slices in the parser. +func yaml_parser_unfold_comments(parser *yaml_parser_t, token *yaml_token_t) { + for parser.comments_head < len(parser.comments) && token.start_mark.index >= parser.comments[parser.comments_head].token_mark.index { + comment := &parser.comments[parser.comments_head] + if len(comment.head) > 0 { + if token.typ == yaml_BLOCK_END_TOKEN { + // No heads on ends, so keep comment.head for a follow up token. + break + } + if len(parser.head_comment) > 0 { + parser.head_comment = append(parser.head_comment, '\n') + } + parser.head_comment = append(parser.head_comment, comment.head...) + } + if len(comment.foot) > 0 { + if len(parser.foot_comment) > 0 { + parser.foot_comment = append(parser.foot_comment, '\n') + } + parser.foot_comment = append(parser.foot_comment, comment.foot...) + } + if len(comment.line) > 0 { + if len(parser.line_comment) > 0 { + parser.line_comment = append(parser.line_comment, '\n') + } + parser.line_comment = append(parser.line_comment, comment.line...) + } + *comment = yaml_comment_t{} + parser.comments_head++ + } +} + +// Remove the next token from the queue (must be called after peek_token). +func skip_token(parser *yaml_parser_t) { + parser.token_available = false + parser.tokens_parsed++ + parser.stream_end_produced = parser.tokens[parser.tokens_head].typ == yaml_STREAM_END_TOKEN + parser.tokens_head++ +} + +// Get the next event. +func yaml_parser_parse(parser *yaml_parser_t, event *yaml_event_t) bool { + // Erase the event object. 
+ *event = yaml_event_t{} + + // No events after the end of the stream or error. + if parser.stream_end_produced || parser.error != yaml_NO_ERROR || parser.state == yaml_PARSE_END_STATE { + return true + } + + // Generate the next event. + return yaml_parser_state_machine(parser, event) +} + +// Set parser error. +func yaml_parser_set_parser_error(parser *yaml_parser_t, problem string, problem_mark yaml_mark_t) bool { + parser.error = yaml_PARSER_ERROR + parser.problem = problem + parser.problem_mark = problem_mark + return false +} + +func yaml_parser_set_parser_error_context(parser *yaml_parser_t, context string, context_mark yaml_mark_t, problem string, problem_mark yaml_mark_t) bool { + parser.error = yaml_PARSER_ERROR + parser.context = context + parser.context_mark = context_mark + parser.problem = problem + parser.problem_mark = problem_mark + return false +} + +// State dispatcher. +func yaml_parser_state_machine(parser *yaml_parser_t, event *yaml_event_t) bool { + //trace("yaml_parser_state_machine", "state:", parser.state.String()) + + switch parser.state { + case yaml_PARSE_STREAM_START_STATE: + return yaml_parser_parse_stream_start(parser, event) + + case yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE: + return yaml_parser_parse_document_start(parser, event, true) + + case yaml_PARSE_DOCUMENT_START_STATE: + return yaml_parser_parse_document_start(parser, event, false) + + case yaml_PARSE_DOCUMENT_CONTENT_STATE: + return yaml_parser_parse_document_content(parser, event) + + case yaml_PARSE_DOCUMENT_END_STATE: + return yaml_parser_parse_document_end(parser, event) + + case yaml_PARSE_BLOCK_NODE_STATE: + return yaml_parser_parse_node(parser, event, true, false) + + case yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE: + return yaml_parser_parse_node(parser, event, true, true) + + case yaml_PARSE_FLOW_NODE_STATE: + return yaml_parser_parse_node(parser, event, false, false) + + case yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE: + return 
yaml_parser_parse_block_sequence_entry(parser, event, true) + + case yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE: + return yaml_parser_parse_block_sequence_entry(parser, event, false) + + case yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE: + return yaml_parser_parse_indentless_sequence_entry(parser, event) + + case yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE: + return yaml_parser_parse_block_mapping_key(parser, event, true) + + case yaml_PARSE_BLOCK_MAPPING_KEY_STATE: + return yaml_parser_parse_block_mapping_key(parser, event, false) + + case yaml_PARSE_BLOCK_MAPPING_VALUE_STATE: + return yaml_parser_parse_block_mapping_value(parser, event) + + case yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE: + return yaml_parser_parse_flow_sequence_entry(parser, event, true) + + case yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE: + return yaml_parser_parse_flow_sequence_entry(parser, event, false) + + case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE: + return yaml_parser_parse_flow_sequence_entry_mapping_key(parser, event) + + case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE: + return yaml_parser_parse_flow_sequence_entry_mapping_value(parser, event) + + case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE: + return yaml_parser_parse_flow_sequence_entry_mapping_end(parser, event) + + case yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE: + return yaml_parser_parse_flow_mapping_key(parser, event, true) + + case yaml_PARSE_FLOW_MAPPING_KEY_STATE: + return yaml_parser_parse_flow_mapping_key(parser, event, false) + + case yaml_PARSE_FLOW_MAPPING_VALUE_STATE: + return yaml_parser_parse_flow_mapping_value(parser, event, false) + + case yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE: + return yaml_parser_parse_flow_mapping_value(parser, event, true) + + default: + panic("invalid parser state") + } +} + +// Parse the production: +// stream ::= STREAM-START implicit_document? 
explicit_document* STREAM-END +// ************ +func yaml_parser_parse_stream_start(parser *yaml_parser_t, event *yaml_event_t) bool { + token := peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_STREAM_START_TOKEN { + return yaml_parser_set_parser_error(parser, "did not find expected ", token.start_mark) + } + parser.state = yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE + *event = yaml_event_t{ + typ: yaml_STREAM_START_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + encoding: token.encoding, + } + skip_token(parser) + return true +} + +// Parse the productions: +// implicit_document ::= block_node DOCUMENT-END* +// * +// explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* +// ************************* +func yaml_parser_parse_document_start(parser *yaml_parser_t, event *yaml_event_t, implicit bool) bool { + + token := peek_token(parser) + if token == nil { + return false + } + + // Parse extra document end indicators. + if !implicit { + for token.typ == yaml_DOCUMENT_END_TOKEN { + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + } + } + + if implicit && token.typ != yaml_VERSION_DIRECTIVE_TOKEN && + token.typ != yaml_TAG_DIRECTIVE_TOKEN && + token.typ != yaml_DOCUMENT_START_TOKEN && + token.typ != yaml_STREAM_END_TOKEN { + // Parse an implicit document. + if !yaml_parser_process_directives(parser, nil, nil) { + return false + } + parser.states = append(parser.states, yaml_PARSE_DOCUMENT_END_STATE) + parser.state = yaml_PARSE_BLOCK_NODE_STATE + + var head_comment []byte + if len(parser.head_comment) > 0 { + // [Go] Scan the header comment backwards, and if an empty line is found, break + // the header so the part before the last empty line goes into the + // document header, while the bottom of it goes into a follow up event. 
+ for i := len(parser.head_comment) - 1; i > 0; i-- { + if parser.head_comment[i] == '\n' { + if i == len(parser.head_comment)-1 { + head_comment = parser.head_comment[:i] + parser.head_comment = parser.head_comment[i+1:] + break + } else if parser.head_comment[i-1] == '\n' { + head_comment = parser.head_comment[:i-1] + parser.head_comment = parser.head_comment[i+1:] + break + } + } + } + } + + *event = yaml_event_t{ + typ: yaml_DOCUMENT_START_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + + head_comment: head_comment, + } + + } else if token.typ != yaml_STREAM_END_TOKEN { + // Parse an explicit document. + var version_directive *yaml_version_directive_t + var tag_directives []yaml_tag_directive_t + start_mark := token.start_mark + if !yaml_parser_process_directives(parser, &version_directive, &tag_directives) { + return false + } + token = peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_DOCUMENT_START_TOKEN { + yaml_parser_set_parser_error(parser, + "did not find expected ", token.start_mark) + return false + } + parser.states = append(parser.states, yaml_PARSE_DOCUMENT_END_STATE) + parser.state = yaml_PARSE_DOCUMENT_CONTENT_STATE + end_mark := token.end_mark + + *event = yaml_event_t{ + typ: yaml_DOCUMENT_START_EVENT, + start_mark: start_mark, + end_mark: end_mark, + version_directive: version_directive, + tag_directives: tag_directives, + implicit: false, + } + skip_token(parser) + + } else { + // Parse the stream end. + parser.state = yaml_PARSE_END_STATE + *event = yaml_event_t{ + typ: yaml_STREAM_END_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + } + skip_token(parser) + } + + return true +} + +// Parse the productions: +// explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? 
DOCUMENT-END* +// *********** +// +func yaml_parser_parse_document_content(parser *yaml_parser_t, event *yaml_event_t) bool { + token := peek_token(parser) + if token == nil { + return false + } + + if token.typ == yaml_VERSION_DIRECTIVE_TOKEN || + token.typ == yaml_TAG_DIRECTIVE_TOKEN || + token.typ == yaml_DOCUMENT_START_TOKEN || + token.typ == yaml_DOCUMENT_END_TOKEN || + token.typ == yaml_STREAM_END_TOKEN { + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + return yaml_parser_process_empty_scalar(parser, event, + token.start_mark) + } + return yaml_parser_parse_node(parser, event, true, false) +} + +// Parse the productions: +// implicit_document ::= block_node DOCUMENT-END* +// ************* +// explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* +// +func yaml_parser_parse_document_end(parser *yaml_parser_t, event *yaml_event_t) bool { + token := peek_token(parser) + if token == nil { + return false + } + + start_mark := token.start_mark + end_mark := token.start_mark + + implicit := true + if token.typ == yaml_DOCUMENT_END_TOKEN { + end_mark = token.end_mark + skip_token(parser) + implicit = false + } + + parser.tag_directives = parser.tag_directives[:0] + + parser.state = yaml_PARSE_DOCUMENT_START_STATE + *event = yaml_event_t{ + typ: yaml_DOCUMENT_END_EVENT, + start_mark: start_mark, + end_mark: end_mark, + implicit: implicit, + } + yaml_parser_set_event_comments(parser, event) + if len(event.head_comment) > 0 && len(event.foot_comment) == 0 { + event.foot_comment = event.head_comment + event.head_comment = nil + } + return true +} + +func yaml_parser_set_event_comments(parser *yaml_parser_t, event *yaml_event_t) { + event.head_comment = parser.head_comment + event.line_comment = parser.line_comment + event.foot_comment = parser.foot_comment + parser.head_comment = nil + parser.line_comment = nil + parser.foot_comment = nil + parser.tail_comment = nil + parser.stem_comment = 
nil +} + +// Parse the productions: +// block_node_or_indentless_sequence ::= +// ALIAS +// ***** +// | properties (block_content | indentless_block_sequence)? +// ********** * +// | block_content | indentless_block_sequence +// * +// block_node ::= ALIAS +// ***** +// | properties block_content? +// ********** * +// | block_content +// * +// flow_node ::= ALIAS +// ***** +// | properties flow_content? +// ********** * +// | flow_content +// * +// properties ::= TAG ANCHOR? | ANCHOR TAG? +// ************************* +// block_content ::= block_collection | flow_collection | SCALAR +// ****** +// flow_content ::= flow_collection | SCALAR +// ****** +func yaml_parser_parse_node(parser *yaml_parser_t, event *yaml_event_t, block, indentless_sequence bool) bool { + //defer trace("yaml_parser_parse_node", "block:", block, "indentless_sequence:", indentless_sequence)() + + token := peek_token(parser) + if token == nil { + return false + } + + if token.typ == yaml_ALIAS_TOKEN { + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + *event = yaml_event_t{ + typ: yaml_ALIAS_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + anchor: token.value, + } + yaml_parser_set_event_comments(parser, event) + skip_token(parser) + return true + } + + start_mark := token.start_mark + end_mark := token.start_mark + + var tag_token bool + var tag_handle, tag_suffix, anchor []byte + var tag_mark yaml_mark_t + if token.typ == yaml_ANCHOR_TOKEN { + anchor = token.value + start_mark = token.start_mark + end_mark = token.end_mark + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + if token.typ == yaml_TAG_TOKEN { + tag_token = true + tag_handle = token.value + tag_suffix = token.suffix + tag_mark = token.start_mark + end_mark = token.end_mark + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + } + } else if token.typ == yaml_TAG_TOKEN { + 
tag_token = true + tag_handle = token.value + tag_suffix = token.suffix + start_mark = token.start_mark + tag_mark = token.start_mark + end_mark = token.end_mark + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + if token.typ == yaml_ANCHOR_TOKEN { + anchor = token.value + end_mark = token.end_mark + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + } + } + + var tag []byte + if tag_token { + if len(tag_handle) == 0 { + tag = tag_suffix + tag_suffix = nil + } else { + for i := range parser.tag_directives { + if bytes.Equal(parser.tag_directives[i].handle, tag_handle) { + tag = append([]byte(nil), parser.tag_directives[i].prefix...) + tag = append(tag, tag_suffix...) + break + } + } + if len(tag) == 0 { + yaml_parser_set_parser_error_context(parser, + "while parsing a node", start_mark, + "found undefined tag handle", tag_mark) + return false + } + } + } + + implicit := len(tag) == 0 + if indentless_sequence && token.typ == yaml_BLOCK_ENTRY_TOKEN { + end_mark = token.end_mark + parser.state = yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE + *event = yaml_event_t{ + typ: yaml_SEQUENCE_START_EVENT, + start_mark: start_mark, + end_mark: end_mark, + anchor: anchor, + tag: tag, + implicit: implicit, + style: yaml_style_t(yaml_BLOCK_SEQUENCE_STYLE), + } + return true + } + if token.typ == yaml_SCALAR_TOKEN { + var plain_implicit, quoted_implicit bool + end_mark = token.end_mark + if (len(tag) == 0 && token.style == yaml_PLAIN_SCALAR_STYLE) || (len(tag) == 1 && tag[0] == '!') { + plain_implicit = true + } else if len(tag) == 0 { + quoted_implicit = true + } + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + + *event = yaml_event_t{ + typ: yaml_SCALAR_EVENT, + start_mark: start_mark, + end_mark: end_mark, + anchor: anchor, + tag: tag, + value: token.value, + implicit: plain_implicit, + quoted_implicit: quoted_implicit, + style: 
yaml_style_t(token.style), + } + yaml_parser_set_event_comments(parser, event) + skip_token(parser) + return true + } + if token.typ == yaml_FLOW_SEQUENCE_START_TOKEN { + // [Go] Some of the events below can be merged as they differ only on style. + end_mark = token.end_mark + parser.state = yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE + *event = yaml_event_t{ + typ: yaml_SEQUENCE_START_EVENT, + start_mark: start_mark, + end_mark: end_mark, + anchor: anchor, + tag: tag, + implicit: implicit, + style: yaml_style_t(yaml_FLOW_SEQUENCE_STYLE), + } + yaml_parser_set_event_comments(parser, event) + return true + } + if token.typ == yaml_FLOW_MAPPING_START_TOKEN { + end_mark = token.end_mark + parser.state = yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE + *event = yaml_event_t{ + typ: yaml_MAPPING_START_EVENT, + start_mark: start_mark, + end_mark: end_mark, + anchor: anchor, + tag: tag, + implicit: implicit, + style: yaml_style_t(yaml_FLOW_MAPPING_STYLE), + } + yaml_parser_set_event_comments(parser, event) + return true + } + if block && token.typ == yaml_BLOCK_SEQUENCE_START_TOKEN { + end_mark = token.end_mark + parser.state = yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE + *event = yaml_event_t{ + typ: yaml_SEQUENCE_START_EVENT, + start_mark: start_mark, + end_mark: end_mark, + anchor: anchor, + tag: tag, + implicit: implicit, + style: yaml_style_t(yaml_BLOCK_SEQUENCE_STYLE), + } + if parser.stem_comment != nil { + event.head_comment = parser.stem_comment + parser.stem_comment = nil + } + return true + } + if block && token.typ == yaml_BLOCK_MAPPING_START_TOKEN { + end_mark = token.end_mark + parser.state = yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE + *event = yaml_event_t{ + typ: yaml_MAPPING_START_EVENT, + start_mark: start_mark, + end_mark: end_mark, + anchor: anchor, + tag: tag, + implicit: implicit, + style: yaml_style_t(yaml_BLOCK_MAPPING_STYLE), + } + if parser.stem_comment != nil { + event.head_comment = parser.stem_comment + parser.stem_comment = nil + } + return true + } 
+ if len(anchor) > 0 || len(tag) > 0 { + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + + *event = yaml_event_t{ + typ: yaml_SCALAR_EVENT, + start_mark: start_mark, + end_mark: end_mark, + anchor: anchor, + tag: tag, + implicit: implicit, + quoted_implicit: false, + style: yaml_style_t(yaml_PLAIN_SCALAR_STYLE), + } + return true + } + + context := "while parsing a flow node" + if block { + context = "while parsing a block node" + } + yaml_parser_set_parser_error_context(parser, context, start_mark, + "did not find expected node content", token.start_mark) + return false +} + +// Parse the productions: +// block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END +// ******************** *********** * ********* +// +func yaml_parser_parse_block_sequence_entry(parser *yaml_parser_t, event *yaml_event_t, first bool) bool { + if first { + token := peek_token(parser) + if token == nil { + return false + } + parser.marks = append(parser.marks, token.start_mark) + skip_token(parser) + } + + token := peek_token(parser) + if token == nil { + return false + } + + if token.typ == yaml_BLOCK_ENTRY_TOKEN { + mark := token.end_mark + prior_head_len := len(parser.head_comment) + skip_token(parser) + yaml_parser_split_stem_comment(parser, prior_head_len) + token = peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_BLOCK_ENTRY_TOKEN && token.typ != yaml_BLOCK_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE) + return yaml_parser_parse_node(parser, event, true, false) + } else { + parser.state = yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE + return yaml_parser_process_empty_scalar(parser, event, mark) + } + } + if token.typ == yaml_BLOCK_END_TOKEN { + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + parser.marks = parser.marks[:len(parser.marks)-1] + + *event = yaml_event_t{ + typ: 
yaml_SEQUENCE_END_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + } + + skip_token(parser) + return true + } + + context_mark := parser.marks[len(parser.marks)-1] + parser.marks = parser.marks[:len(parser.marks)-1] + return yaml_parser_set_parser_error_context(parser, + "while parsing a block collection", context_mark, + "did not find expected '-' indicator", token.start_mark) +} + +// Parse the productions: +// indentless_sequence ::= (BLOCK-ENTRY block_node?)+ +// *********** * +func yaml_parser_parse_indentless_sequence_entry(parser *yaml_parser_t, event *yaml_event_t) bool { + token := peek_token(parser) + if token == nil { + return false + } + + if token.typ == yaml_BLOCK_ENTRY_TOKEN { + mark := token.end_mark + prior_head_len := len(parser.head_comment) + skip_token(parser) + yaml_parser_split_stem_comment(parser, prior_head_len) + token = peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_BLOCK_ENTRY_TOKEN && + token.typ != yaml_KEY_TOKEN && + token.typ != yaml_VALUE_TOKEN && + token.typ != yaml_BLOCK_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE) + return yaml_parser_parse_node(parser, event, true, false) + } + parser.state = yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE + return yaml_parser_process_empty_scalar(parser, event, mark) + } + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + + *event = yaml_event_t{ + typ: yaml_SEQUENCE_END_EVENT, + start_mark: token.start_mark, + end_mark: token.start_mark, // [Go] Shouldn't this be token.end_mark? + } + return true +} + +// Split stem comment from head comment. +// +// When a sequence or map is found under a sequence entry, the former head comment +// is assigned to the underlying sequence or map as a whole, not the individual +// sequence or map entry as would be expected otherwise. 
To handle this case the +// previous head comment is moved aside as the stem comment. +func yaml_parser_split_stem_comment(parser *yaml_parser_t, stem_len int) { + if stem_len == 0 { + return + } + + token := peek_token(parser) + if token == nil || token.typ != yaml_BLOCK_SEQUENCE_START_TOKEN && token.typ != yaml_BLOCK_MAPPING_START_TOKEN { + return + } + + parser.stem_comment = parser.head_comment[:stem_len] + if len(parser.head_comment) == stem_len { + parser.head_comment = nil + } else { + // Copy suffix to prevent very strange bugs if someone ever appends + // further bytes to the prefix in the stem_comment slice above. + parser.head_comment = append([]byte(nil), parser.head_comment[stem_len+1:]...) + } +} + +// Parse the productions: +// block_mapping ::= BLOCK-MAPPING_START +// ******************* +// ((KEY block_node_or_indentless_sequence?)? +// *** * +// (VALUE block_node_or_indentless_sequence?)?)* +// +// BLOCK-END +// ********* +// +func yaml_parser_parse_block_mapping_key(parser *yaml_parser_t, event *yaml_event_t, first bool) bool { + if first { + token := peek_token(parser) + if token == nil { + return false + } + parser.marks = append(parser.marks, token.start_mark) + skip_token(parser) + } + + token := peek_token(parser) + if token == nil { + return false + } + + // [Go] A tail comment was left from the prior mapping value processed. Emit an event + // as it needs to be processed with that value and not the following key. 
+ if len(parser.tail_comment) > 0 { + *event = yaml_event_t{ + typ: yaml_TAIL_COMMENT_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + foot_comment: parser.tail_comment, + } + parser.tail_comment = nil + return true + } + + if token.typ == yaml_KEY_TOKEN { + mark := token.end_mark + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_KEY_TOKEN && + token.typ != yaml_VALUE_TOKEN && + token.typ != yaml_BLOCK_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_BLOCK_MAPPING_VALUE_STATE) + return yaml_parser_parse_node(parser, event, true, true) + } else { + parser.state = yaml_PARSE_BLOCK_MAPPING_VALUE_STATE + return yaml_parser_process_empty_scalar(parser, event, mark) + } + } else if token.typ == yaml_BLOCK_END_TOKEN { + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + parser.marks = parser.marks[:len(parser.marks)-1] + *event = yaml_event_t{ + typ: yaml_MAPPING_END_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + } + yaml_parser_set_event_comments(parser, event) + skip_token(parser) + return true + } + + context_mark := parser.marks[len(parser.marks)-1] + parser.marks = parser.marks[:len(parser.marks)-1] + return yaml_parser_set_parser_error_context(parser, + "while parsing a block mapping", context_mark, + "did not find expected key", token.start_mark) +} + +// Parse the productions: +// block_mapping ::= BLOCK-MAPPING_START +// +// ((KEY block_node_or_indentless_sequence?)? 
+// +// (VALUE block_node_or_indentless_sequence?)?)* +// ***** * +// BLOCK-END +// +// +func yaml_parser_parse_block_mapping_value(parser *yaml_parser_t, event *yaml_event_t) bool { + token := peek_token(parser) + if token == nil { + return false + } + if token.typ == yaml_VALUE_TOKEN { + mark := token.end_mark + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_KEY_TOKEN && + token.typ != yaml_VALUE_TOKEN && + token.typ != yaml_BLOCK_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_BLOCK_MAPPING_KEY_STATE) + return yaml_parser_parse_node(parser, event, true, true) + } + parser.state = yaml_PARSE_BLOCK_MAPPING_KEY_STATE + return yaml_parser_process_empty_scalar(parser, event, mark) + } + parser.state = yaml_PARSE_BLOCK_MAPPING_KEY_STATE + return yaml_parser_process_empty_scalar(parser, event, token.start_mark) +} + +// Parse the productions: +// flow_sequence ::= FLOW-SEQUENCE-START +// ******************* +// (flow_sequence_entry FLOW-ENTRY)* +// * ********** +// flow_sequence_entry? +// * +// FLOW-SEQUENCE-END +// ***************** +// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? 
+// * +// +func yaml_parser_parse_flow_sequence_entry(parser *yaml_parser_t, event *yaml_event_t, first bool) bool { + if first { + token := peek_token(parser) + if token == nil { + return false + } + parser.marks = append(parser.marks, token.start_mark) + skip_token(parser) + } + token := peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_FLOW_SEQUENCE_END_TOKEN { + if !first { + if token.typ == yaml_FLOW_ENTRY_TOKEN { + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + } else { + context_mark := parser.marks[len(parser.marks)-1] + parser.marks = parser.marks[:len(parser.marks)-1] + return yaml_parser_set_parser_error_context(parser, + "while parsing a flow sequence", context_mark, + "did not find expected ',' or ']'", token.start_mark) + } + } + + if token.typ == yaml_KEY_TOKEN { + parser.state = yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE + *event = yaml_event_t{ + typ: yaml_MAPPING_START_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + implicit: true, + style: yaml_style_t(yaml_FLOW_MAPPING_STYLE), + } + skip_token(parser) + return true + } else if token.typ != yaml_FLOW_SEQUENCE_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE) + return yaml_parser_parse_node(parser, event, false, false) + } + } + + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + parser.marks = parser.marks[:len(parser.marks)-1] + + *event = yaml_event_t{ + typ: yaml_SEQUENCE_END_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + } + yaml_parser_set_event_comments(parser, event) + + skip_token(parser) + return true +} + +// +// Parse the productions: +// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? 
+// *** * +// +func yaml_parser_parse_flow_sequence_entry_mapping_key(parser *yaml_parser_t, event *yaml_event_t) bool { + token := peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_VALUE_TOKEN && + token.typ != yaml_FLOW_ENTRY_TOKEN && + token.typ != yaml_FLOW_SEQUENCE_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE) + return yaml_parser_parse_node(parser, event, false, false) + } + mark := token.end_mark + skip_token(parser) + parser.state = yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE + return yaml_parser_process_empty_scalar(parser, event, mark) +} + +// Parse the productions: +// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? +// ***** * +// +func yaml_parser_parse_flow_sequence_entry_mapping_value(parser *yaml_parser_t, event *yaml_event_t) bool { + token := peek_token(parser) + if token == nil { + return false + } + if token.typ == yaml_VALUE_TOKEN { + skip_token(parser) + token := peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_FLOW_ENTRY_TOKEN && token.typ != yaml_FLOW_SEQUENCE_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE) + return yaml_parser_parse_node(parser, event, false, false) + } + } + parser.state = yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE + return yaml_parser_process_empty_scalar(parser, event, token.start_mark) +} + +// Parse the productions: +// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? +// * +// +func yaml_parser_parse_flow_sequence_entry_mapping_end(parser *yaml_parser_t, event *yaml_event_t) bool { + token := peek_token(parser) + if token == nil { + return false + } + parser.state = yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE + *event = yaml_event_t{ + typ: yaml_MAPPING_END_EVENT, + start_mark: token.start_mark, + end_mark: token.start_mark, // [Go] Shouldn't this be end_mark? 
+ } + return true +} + +// Parse the productions: +// flow_mapping ::= FLOW-MAPPING-START +// ****************** +// (flow_mapping_entry FLOW-ENTRY)* +// * ********** +// flow_mapping_entry? +// ****************** +// FLOW-MAPPING-END +// **************** +// flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? +// * *** * +// +func yaml_parser_parse_flow_mapping_key(parser *yaml_parser_t, event *yaml_event_t, first bool) bool { + if first { + token := peek_token(parser) + parser.marks = append(parser.marks, token.start_mark) + skip_token(parser) + } + + token := peek_token(parser) + if token == nil { + return false + } + + if token.typ != yaml_FLOW_MAPPING_END_TOKEN { + if !first { + if token.typ == yaml_FLOW_ENTRY_TOKEN { + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + } else { + context_mark := parser.marks[len(parser.marks)-1] + parser.marks = parser.marks[:len(parser.marks)-1] + return yaml_parser_set_parser_error_context(parser, + "while parsing a flow mapping", context_mark, + "did not find expected ',' or '}'", token.start_mark) + } + } + + if token.typ == yaml_KEY_TOKEN { + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_VALUE_TOKEN && + token.typ != yaml_FLOW_ENTRY_TOKEN && + token.typ != yaml_FLOW_MAPPING_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_FLOW_MAPPING_VALUE_STATE) + return yaml_parser_parse_node(parser, event, false, false) + } else { + parser.state = yaml_PARSE_FLOW_MAPPING_VALUE_STATE + return yaml_parser_process_empty_scalar(parser, event, token.start_mark) + } + } else if token.typ != yaml_FLOW_MAPPING_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE) + return yaml_parser_parse_node(parser, event, false, false) + } + } + + parser.state = parser.states[len(parser.states)-1] + parser.states = parser.states[:len(parser.states)-1] + parser.marks = 
parser.marks[:len(parser.marks)-1] + *event = yaml_event_t{ + typ: yaml_MAPPING_END_EVENT, + start_mark: token.start_mark, + end_mark: token.end_mark, + } + yaml_parser_set_event_comments(parser, event) + skip_token(parser) + return true +} + +// Parse the productions: +// flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? +// * ***** * +// +func yaml_parser_parse_flow_mapping_value(parser *yaml_parser_t, event *yaml_event_t, empty bool) bool { + token := peek_token(parser) + if token == nil { + return false + } + if empty { + parser.state = yaml_PARSE_FLOW_MAPPING_KEY_STATE + return yaml_parser_process_empty_scalar(parser, event, token.start_mark) + } + if token.typ == yaml_VALUE_TOKEN { + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + if token.typ != yaml_FLOW_ENTRY_TOKEN && token.typ != yaml_FLOW_MAPPING_END_TOKEN { + parser.states = append(parser.states, yaml_PARSE_FLOW_MAPPING_KEY_STATE) + return yaml_parser_parse_node(parser, event, false, false) + } + } + parser.state = yaml_PARSE_FLOW_MAPPING_KEY_STATE + return yaml_parser_process_empty_scalar(parser, event, token.start_mark) +} + +// Generate an empty scalar event. +func yaml_parser_process_empty_scalar(parser *yaml_parser_t, event *yaml_event_t, mark yaml_mark_t) bool { + *event = yaml_event_t{ + typ: yaml_SCALAR_EVENT, + start_mark: mark, + end_mark: mark, + value: nil, // Empty + implicit: true, + style: yaml_style_t(yaml_PLAIN_SCALAR_STYLE), + } + return true +} + +var default_tag_directives = []yaml_tag_directive_t{ + {[]byte("!"), []byte("!")}, + {[]byte("!!"), []byte("tag:yaml.org,2002:")}, +} + +// Parse directives. 
+func yaml_parser_process_directives(parser *yaml_parser_t, + version_directive_ref **yaml_version_directive_t, + tag_directives_ref *[]yaml_tag_directive_t) bool { + + var version_directive *yaml_version_directive_t + var tag_directives []yaml_tag_directive_t + + token := peek_token(parser) + if token == nil { + return false + } + + for token.typ == yaml_VERSION_DIRECTIVE_TOKEN || token.typ == yaml_TAG_DIRECTIVE_TOKEN { + if token.typ == yaml_VERSION_DIRECTIVE_TOKEN { + if version_directive != nil { + yaml_parser_set_parser_error(parser, + "found duplicate %YAML directive", token.start_mark) + return false + } + if token.major != 1 || token.minor != 1 { + yaml_parser_set_parser_error(parser, + "found incompatible YAML document", token.start_mark) + return false + } + version_directive = &yaml_version_directive_t{ + major: token.major, + minor: token.minor, + } + } else if token.typ == yaml_TAG_DIRECTIVE_TOKEN { + value := yaml_tag_directive_t{ + handle: token.value, + prefix: token.prefix, + } + if !yaml_parser_append_tag_directive(parser, value, false, token.start_mark) { + return false + } + tag_directives = append(tag_directives, value) + } + + skip_token(parser) + token = peek_token(parser) + if token == nil { + return false + } + } + + for i := range default_tag_directives { + if !yaml_parser_append_tag_directive(parser, default_tag_directives[i], true, token.start_mark) { + return false + } + } + + if version_directive_ref != nil { + *version_directive_ref = version_directive + } + if tag_directives_ref != nil { + *tag_directives_ref = tag_directives + } + return true +} + +// Append a tag directive to the directives stack. 
+func yaml_parser_append_tag_directive(parser *yaml_parser_t, value yaml_tag_directive_t, allow_duplicates bool, mark yaml_mark_t) bool { + for i := range parser.tag_directives { + if bytes.Equal(value.handle, parser.tag_directives[i].handle) { + if allow_duplicates { + return true + } + return yaml_parser_set_parser_error(parser, "found duplicate %TAG directive", mark) + } + } + + // [Go] I suspect the copy is unnecessary. This was likely done + // because there was no way to track ownership of the data. + value_copy := yaml_tag_directive_t{ + handle: make([]byte, len(value.handle)), + prefix: make([]byte, len(value.prefix)), + } + copy(value_copy.handle, value.handle) + copy(value_copy.prefix, value.prefix) + parser.tag_directives = append(parser.tag_directives, value_copy) + return true +} diff --git a/vendor/gopkg.in/yaml.v3/readerc.go b/vendor/gopkg.in/yaml.v3/readerc.go new file mode 100644 index 000000000..b7de0a89c --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/readerc.go @@ -0,0 +1,434 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// Copyright (c) 2006-2010 Kirill Simonov +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package yaml + +import ( + "io" +) + +// Set the reader error and return 0. +func yaml_parser_set_reader_error(parser *yaml_parser_t, problem string, offset int, value int) bool { + parser.error = yaml_READER_ERROR + parser.problem = problem + parser.problem_offset = offset + parser.problem_value = value + return false +} + +// Byte order marks. +const ( + bom_UTF8 = "\xef\xbb\xbf" + bom_UTF16LE = "\xff\xfe" + bom_UTF16BE = "\xfe\xff" +) + +// Determine the input stream encoding by checking the BOM symbol. If no BOM is +// found, the UTF-8 encoding is assumed. Return 1 on success, 0 on failure. +func yaml_parser_determine_encoding(parser *yaml_parser_t) bool { + // Ensure that we had enough bytes in the raw buffer. + for !parser.eof && len(parser.raw_buffer)-parser.raw_buffer_pos < 3 { + if !yaml_parser_update_raw_buffer(parser) { + return false + } + } + + // Determine the encoding. + buf := parser.raw_buffer + pos := parser.raw_buffer_pos + avail := len(buf) - pos + if avail >= 2 && buf[pos] == bom_UTF16LE[0] && buf[pos+1] == bom_UTF16LE[1] { + parser.encoding = yaml_UTF16LE_ENCODING + parser.raw_buffer_pos += 2 + parser.offset += 2 + } else if avail >= 2 && buf[pos] == bom_UTF16BE[0] && buf[pos+1] == bom_UTF16BE[1] { + parser.encoding = yaml_UTF16BE_ENCODING + parser.raw_buffer_pos += 2 + parser.offset += 2 + } else if avail >= 3 && buf[pos] == bom_UTF8[0] && buf[pos+1] == bom_UTF8[1] && buf[pos+2] == bom_UTF8[2] { + parser.encoding = yaml_UTF8_ENCODING + parser.raw_buffer_pos += 3 + parser.offset += 3 + } else { + parser.encoding = yaml_UTF8_ENCODING + } + return true +} + +// Update the raw buffer. 
+func yaml_parser_update_raw_buffer(parser *yaml_parser_t) bool { + size_read := 0 + + // Return if the raw buffer is full. + if parser.raw_buffer_pos == 0 && len(parser.raw_buffer) == cap(parser.raw_buffer) { + return true + } + + // Return on EOF. + if parser.eof { + return true + } + + // Move the remaining bytes in the raw buffer to the beginning. + if parser.raw_buffer_pos > 0 && parser.raw_buffer_pos < len(parser.raw_buffer) { + copy(parser.raw_buffer, parser.raw_buffer[parser.raw_buffer_pos:]) + } + parser.raw_buffer = parser.raw_buffer[:len(parser.raw_buffer)-parser.raw_buffer_pos] + parser.raw_buffer_pos = 0 + + // Call the read handler to fill the buffer. + size_read, err := parser.read_handler(parser, parser.raw_buffer[len(parser.raw_buffer):cap(parser.raw_buffer)]) + parser.raw_buffer = parser.raw_buffer[:len(parser.raw_buffer)+size_read] + if err == io.EOF { + parser.eof = true + } else if err != nil { + return yaml_parser_set_reader_error(parser, "input error: "+err.Error(), parser.offset, -1) + } + return true +} + +// Ensure that the buffer contains at least `length` characters. +// Return true on success, false on failure. +// +// The length is supposed to be significantly less that the buffer size. +func yaml_parser_update_buffer(parser *yaml_parser_t, length int) bool { + if parser.read_handler == nil { + panic("read handler must be set") + } + + // [Go] This function was changed to guarantee the requested length size at EOF. + // The fact we need to do this is pretty awful, but the description above implies + // for that to be the case, and there are tests + + // If the EOF flag is set and the raw buffer is empty, do nothing. + if parser.eof && parser.raw_buffer_pos == len(parser.raw_buffer) { + // [Go] ACTUALLY! Read the documentation of this function above. + // This is just broken. To return true, we need to have the + // given length in the buffer. 
Not doing that means every single + // check that calls this function to make sure the buffer has a + // given length is Go) panicking; or C) accessing invalid memory. + //return true + } + + // Return if the buffer contains enough characters. + if parser.unread >= length { + return true + } + + // Determine the input encoding if it is not known yet. + if parser.encoding == yaml_ANY_ENCODING { + if !yaml_parser_determine_encoding(parser) { + return false + } + } + + // Move the unread characters to the beginning of the buffer. + buffer_len := len(parser.buffer) + if parser.buffer_pos > 0 && parser.buffer_pos < buffer_len { + copy(parser.buffer, parser.buffer[parser.buffer_pos:]) + buffer_len -= parser.buffer_pos + parser.buffer_pos = 0 + } else if parser.buffer_pos == buffer_len { + buffer_len = 0 + parser.buffer_pos = 0 + } + + // Open the whole buffer for writing, and cut it before returning. + parser.buffer = parser.buffer[:cap(parser.buffer)] + + // Fill the buffer until it has enough characters. + first := true + for parser.unread < length { + + // Fill the raw buffer if necessary. + if !first || parser.raw_buffer_pos == len(parser.raw_buffer) { + if !yaml_parser_update_raw_buffer(parser) { + parser.buffer = parser.buffer[:buffer_len] + return false + } + } + first = false + + // Decode the raw buffer. + inner: + for parser.raw_buffer_pos != len(parser.raw_buffer) { + var value rune + var width int + + raw_unread := len(parser.raw_buffer) - parser.raw_buffer_pos + + // Decode the next character. + switch parser.encoding { + case yaml_UTF8_ENCODING: + // Decode a UTF-8 character. Check RFC 3629 + // (http://www.ietf.org/rfc/rfc3629.txt) for more details. + // + // The following table (taken from the RFC) is used for + // decoding. + // + // Char. 
number range | UTF-8 octet sequence + // (hexadecimal) | (binary) + // --------------------+------------------------------------ + // 0000 0000-0000 007F | 0xxxxxxx + // 0000 0080-0000 07FF | 110xxxxx 10xxxxxx + // 0000 0800-0000 FFFF | 1110xxxx 10xxxxxx 10xxxxxx + // 0001 0000-0010 FFFF | 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + // + // Additionally, the characters in the range 0xD800-0xDFFF + // are prohibited as they are reserved for use with UTF-16 + // surrogate pairs. + + // Determine the length of the UTF-8 sequence. + octet := parser.raw_buffer[parser.raw_buffer_pos] + switch { + case octet&0x80 == 0x00: + width = 1 + case octet&0xE0 == 0xC0: + width = 2 + case octet&0xF0 == 0xE0: + width = 3 + case octet&0xF8 == 0xF0: + width = 4 + default: + // The leading octet is invalid. + return yaml_parser_set_reader_error(parser, + "invalid leading UTF-8 octet", + parser.offset, int(octet)) + } + + // Check if the raw buffer contains an incomplete character. + if width > raw_unread { + if parser.eof { + return yaml_parser_set_reader_error(parser, + "incomplete UTF-8 octet sequence", + parser.offset, -1) + } + break inner + } + + // Decode the leading octet. + switch { + case octet&0x80 == 0x00: + value = rune(octet & 0x7F) + case octet&0xE0 == 0xC0: + value = rune(octet & 0x1F) + case octet&0xF0 == 0xE0: + value = rune(octet & 0x0F) + case octet&0xF8 == 0xF0: + value = rune(octet & 0x07) + default: + value = 0 + } + + // Check and decode the trailing octets. + for k := 1; k < width; k++ { + octet = parser.raw_buffer[parser.raw_buffer_pos+k] + + // Check if the octet is valid. + if (octet & 0xC0) != 0x80 { + return yaml_parser_set_reader_error(parser, + "invalid trailing UTF-8 octet", + parser.offset+k, int(octet)) + } + + // Decode the octet. + value = (value << 6) + rune(octet&0x3F) + } + + // Check the length of the sequence against the value. 
+ switch { + case width == 1: + case width == 2 && value >= 0x80: + case width == 3 && value >= 0x800: + case width == 4 && value >= 0x10000: + default: + return yaml_parser_set_reader_error(parser, + "invalid length of a UTF-8 sequence", + parser.offset, -1) + } + + // Check the range of the value. + if value >= 0xD800 && value <= 0xDFFF || value > 0x10FFFF { + return yaml_parser_set_reader_error(parser, + "invalid Unicode character", + parser.offset, int(value)) + } + + case yaml_UTF16LE_ENCODING, yaml_UTF16BE_ENCODING: + var low, high int + if parser.encoding == yaml_UTF16LE_ENCODING { + low, high = 0, 1 + } else { + low, high = 1, 0 + } + + // The UTF-16 encoding is not as simple as one might + // naively think. Check RFC 2781 + // (http://www.ietf.org/rfc/rfc2781.txt). + // + // Normally, two subsequent bytes describe a Unicode + // character. However a special technique (called a + // surrogate pair) is used for specifying character + // values larger than 0xFFFF. + // + // A surrogate pair consists of two pseudo-characters: + // high surrogate area (0xD800-0xDBFF) + // low surrogate area (0xDC00-0xDFFF) + // + // The following formulas are used for decoding + // and encoding characters using surrogate pairs: + // + // U = U' + 0x10000 (0x01 00 00 <= U <= 0x10 FF FF) + // U' = yyyyyyyyyyxxxxxxxxxx (0 <= U' <= 0x0F FF FF) + // W1 = 110110yyyyyyyyyy + // W2 = 110111xxxxxxxxxx + // + // where U is the character value, W1 is the high surrogate + // area, W2 is the low surrogate area. + + // Check for incomplete UTF-16 character. + if raw_unread < 2 { + if parser.eof { + return yaml_parser_set_reader_error(parser, + "incomplete UTF-16 character", + parser.offset, -1) + } + break inner + } + + // Get the character. + value = rune(parser.raw_buffer[parser.raw_buffer_pos+low]) + + (rune(parser.raw_buffer[parser.raw_buffer_pos+high]) << 8) + + // Check for unexpected low surrogate area. 
+ if value&0xFC00 == 0xDC00 { + return yaml_parser_set_reader_error(parser, + "unexpected low surrogate area", + parser.offset, int(value)) + } + + // Check for a high surrogate area. + if value&0xFC00 == 0xD800 { + width = 4 + + // Check for incomplete surrogate pair. + if raw_unread < 4 { + if parser.eof { + return yaml_parser_set_reader_error(parser, + "incomplete UTF-16 surrogate pair", + parser.offset, -1) + } + break inner + } + + // Get the next character. + value2 := rune(parser.raw_buffer[parser.raw_buffer_pos+low+2]) + + (rune(parser.raw_buffer[parser.raw_buffer_pos+high+2]) << 8) + + // Check for a low surrogate area. + if value2&0xFC00 != 0xDC00 { + return yaml_parser_set_reader_error(parser, + "expected low surrogate area", + parser.offset+2, int(value2)) + } + + // Generate the value of the surrogate pair. + value = 0x10000 + ((value & 0x3FF) << 10) + (value2 & 0x3FF) + } else { + width = 2 + } + + default: + panic("impossible") + } + + // Check if the character is in the allowed range: + // #x9 | #xA | #xD | [#x20-#x7E] (8 bit) + // | #x85 | [#xA0-#xD7FF] | [#xE000-#xFFFD] (16 bit) + // | [#x10000-#x10FFFF] (32 bit) + switch { + case value == 0x09: + case value == 0x0A: + case value == 0x0D: + case value >= 0x20 && value <= 0x7E: + case value == 0x85: + case value >= 0xA0 && value <= 0xD7FF: + case value >= 0xE000 && value <= 0xFFFD: + case value >= 0x10000 && value <= 0x10FFFF: + default: + return yaml_parser_set_reader_error(parser, + "control characters are not allowed", + parser.offset, int(value)) + } + + // Move the raw pointers. + parser.raw_buffer_pos += width + parser.offset += width + + // Finally put the character into the buffer. + if value <= 0x7F { + // 0000 0000-0000 007F . 0xxxxxxx + parser.buffer[buffer_len+0] = byte(value) + buffer_len += 1 + } else if value <= 0x7FF { + // 0000 0080-0000 07FF . 
110xxxxx 10xxxxxx + parser.buffer[buffer_len+0] = byte(0xC0 + (value >> 6)) + parser.buffer[buffer_len+1] = byte(0x80 + (value & 0x3F)) + buffer_len += 2 + } else if value <= 0xFFFF { + // 0000 0800-0000 FFFF . 1110xxxx 10xxxxxx 10xxxxxx + parser.buffer[buffer_len+0] = byte(0xE0 + (value >> 12)) + parser.buffer[buffer_len+1] = byte(0x80 + ((value >> 6) & 0x3F)) + parser.buffer[buffer_len+2] = byte(0x80 + (value & 0x3F)) + buffer_len += 3 + } else { + // 0001 0000-0010 FFFF . 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + parser.buffer[buffer_len+0] = byte(0xF0 + (value >> 18)) + parser.buffer[buffer_len+1] = byte(0x80 + ((value >> 12) & 0x3F)) + parser.buffer[buffer_len+2] = byte(0x80 + ((value >> 6) & 0x3F)) + parser.buffer[buffer_len+3] = byte(0x80 + (value & 0x3F)) + buffer_len += 4 + } + + parser.unread++ + } + + // On EOF, put NUL into the buffer and return. + if parser.eof { + parser.buffer[buffer_len] = 0 + buffer_len++ + parser.unread++ + break + } + } + // [Go] Read the documentation of this function above. To return true, + // we need to have the given length in the buffer. Not doing that means + // every single check that calls this function to make sure the buffer + // has a given length is Go) panicking; or C) accessing invalid memory. + // This happens here due to the EOF above breaking early. + for buffer_len < length { + parser.buffer[buffer_len] = 0 + buffer_len++ + } + parser.buffer = parser.buffer[:buffer_len] + return true +} diff --git a/vendor/gopkg.in/yaml.v3/resolve.go b/vendor/gopkg.in/yaml.v3/resolve.go new file mode 100644 index 000000000..64ae88805 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/resolve.go @@ -0,0 +1,326 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "encoding/base64" + "math" + "regexp" + "strconv" + "strings" + "time" +) + +type resolveMapItem struct { + value interface{} + tag string +} + +var resolveTable = make([]byte, 256) +var resolveMap = make(map[string]resolveMapItem) + +func init() { + t := resolveTable + t[int('+')] = 'S' // Sign + t[int('-')] = 'S' + for _, c := range "0123456789" { + t[int(c)] = 'D' // Digit + } + for _, c := range "yYnNtTfFoO~" { + t[int(c)] = 'M' // In map + } + t[int('.')] = '.' // Float (potentially in map) + + var resolveMapList = []struct { + v interface{} + tag string + l []string + }{ + {true, boolTag, []string{"true", "True", "TRUE"}}, + {false, boolTag, []string{"false", "False", "FALSE"}}, + {nil, nullTag, []string{"", "~", "null", "Null", "NULL"}}, + {math.NaN(), floatTag, []string{".nan", ".NaN", ".NAN"}}, + {math.Inf(+1), floatTag, []string{".inf", ".Inf", ".INF"}}, + {math.Inf(+1), floatTag, []string{"+.inf", "+.Inf", "+.INF"}}, + {math.Inf(-1), floatTag, []string{"-.inf", "-.Inf", "-.INF"}}, + {"<<", mergeTag, []string{"<<"}}, + } + + m := resolveMap + for _, item := range resolveMapList { + for _, s := range item.l { + m[s] = resolveMapItem{item.v, item.tag} + } + } +} + +const ( + nullTag = "!!null" + boolTag = "!!bool" + strTag = "!!str" + intTag = "!!int" + floatTag = "!!float" + timestampTag = "!!timestamp" + seqTag = "!!seq" + mapTag = "!!map" + binaryTag = "!!binary" + mergeTag = "!!merge" +) + +var longTags = make(map[string]string) +var shortTags = make(map[string]string) + +func init() { + for _, stag := 
range []string{nullTag, boolTag, strTag, intTag, floatTag, timestampTag, seqTag, mapTag, binaryTag, mergeTag} { + ltag := longTag(stag) + longTags[stag] = ltag + shortTags[ltag] = stag + } +} + +const longTagPrefix = "tag:yaml.org,2002:" + +func shortTag(tag string) string { + if strings.HasPrefix(tag, longTagPrefix) { + if stag, ok := shortTags[tag]; ok { + return stag + } + return "!!" + tag[len(longTagPrefix):] + } + return tag +} + +func longTag(tag string) string { + if strings.HasPrefix(tag, "!!") { + if ltag, ok := longTags[tag]; ok { + return ltag + } + return longTagPrefix + tag[2:] + } + return tag +} + +func resolvableTag(tag string) bool { + switch tag { + case "", strTag, boolTag, intTag, floatTag, nullTag, timestampTag: + return true + } + return false +} + +var yamlStyleFloat = regexp.MustCompile(`^[-+]?(\.[0-9]+|[0-9]+(\.[0-9]*)?)([eE][-+]?[0-9]+)?$`) + +func resolve(tag string, in string) (rtag string, out interface{}) { + tag = shortTag(tag) + if !resolvableTag(tag) { + return tag, in + } + + defer func() { + switch tag { + case "", rtag, strTag, binaryTag: + return + case floatTag: + if rtag == intTag { + switch v := out.(type) { + case int64: + rtag = floatTag + out = float64(v) + return + case int: + rtag = floatTag + out = float64(v) + return + } + } + } + failf("cannot decode %s `%s` as a %s", shortTag(rtag), in, shortTag(tag)) + }() + + // Any data is accepted as a !!str or !!binary. + // Otherwise, the prefix is enough of a hint about what it might be. + hint := byte('N') + if in != "" { + hint = resolveTable[in[0]] + } + if hint != 0 && tag != strTag && tag != binaryTag { + // Handle things we can lookup in a map. + if item, ok := resolveMap[in]; ok { + return item.tag, item.value + } + + // Base 60 floats are a bad idea, were dropped in YAML 1.2, and + // are purposefully unsupported here. They're still quoted on + // the way out for compatibility with other parser, though. 
+ + switch hint { + case 'M': + // We've already checked the map above. + + case '.': + // Not in the map, so maybe a normal float. + floatv, err := strconv.ParseFloat(in, 64) + if err == nil { + return floatTag, floatv + } + + case 'D', 'S': + // Int, float, or timestamp. + // Only try values as a timestamp if the value is unquoted or there's an explicit + // !!timestamp tag. + if tag == "" || tag == timestampTag { + t, ok := parseTimestamp(in) + if ok { + return timestampTag, t + } + } + + plain := strings.Replace(in, "_", "", -1) + intv, err := strconv.ParseInt(plain, 0, 64) + if err == nil { + if intv == int64(int(intv)) { + return intTag, int(intv) + } else { + return intTag, intv + } + } + uintv, err := strconv.ParseUint(plain, 0, 64) + if err == nil { + return intTag, uintv + } + if yamlStyleFloat.MatchString(plain) { + floatv, err := strconv.ParseFloat(plain, 64) + if err == nil { + return floatTag, floatv + } + } + if strings.HasPrefix(plain, "0b") { + intv, err := strconv.ParseInt(plain[2:], 2, 64) + if err == nil { + if intv == int64(int(intv)) { + return intTag, int(intv) + } else { + return intTag, intv + } + } + uintv, err := strconv.ParseUint(plain[2:], 2, 64) + if err == nil { + return intTag, uintv + } + } else if strings.HasPrefix(plain, "-0b") { + intv, err := strconv.ParseInt("-"+plain[3:], 2, 64) + if err == nil { + if true || intv == int64(int(intv)) { + return intTag, int(intv) + } else { + return intTag, intv + } + } + } + // Octals as introduced in version 1.2 of the spec. + // Octals from the 1.1 spec, spelled as 0777, are still + // decoded by default in v3 as well for compatibility. + // May be dropped in v4 depending on how usage evolves. 
+ if strings.HasPrefix(plain, "0o") { + intv, err := strconv.ParseInt(plain[2:], 8, 64) + if err == nil { + if intv == int64(int(intv)) { + return intTag, int(intv) + } else { + return intTag, intv + } + } + uintv, err := strconv.ParseUint(plain[2:], 8, 64) + if err == nil { + return intTag, uintv + } + } else if strings.HasPrefix(plain, "-0o") { + intv, err := strconv.ParseInt("-"+plain[3:], 8, 64) + if err == nil { + if true || intv == int64(int(intv)) { + return intTag, int(intv) + } else { + return intTag, intv + } + } + } + default: + panic("internal error: missing handler for resolver table: " + string(rune(hint)) + " (with " + in + ")") + } + } + return strTag, in +} + +// encodeBase64 encodes s as base64 that is broken up into multiple lines +// as appropriate for the resulting length. +func encodeBase64(s string) string { + const lineLen = 70 + encLen := base64.StdEncoding.EncodedLen(len(s)) + lines := encLen/lineLen + 1 + buf := make([]byte, encLen*2+lines) + in := buf[0:encLen] + out := buf[encLen:] + base64.StdEncoding.Encode(in, []byte(s)) + k := 0 + for i := 0; i < len(in); i += lineLen { + j := i + lineLen + if j > len(in) { + j = len(in) + } + k += copy(out[k:], in[i:j]) + if lines > 1 { + out[k] = '\n' + k++ + } + } + return string(out[:k]) +} + +// This is a subset of the formats allowed by the regular expression +// defined at http://yaml.org/type/timestamp.html. +var allowedTimestampFormats = []string{ + "2006-1-2T15:4:5.999999999Z07:00", // RCF3339Nano with short date fields. + "2006-1-2t15:4:5.999999999Z07:00", // RFC3339Nano with short date fields and lower-case "t". + "2006-1-2 15:4:5.999999999", // space separated with no time zone + "2006-1-2", // date only + // Notable exception: time.Parse cannot handle: "2001-12-14 21:59:43.10 -5" + // from the set of examples. +} + +// parseTimestamp parses s as a timestamp string and +// returns the timestamp and reports whether it succeeded. 
+// Timestamp formats are defined at http://yaml.org/type/timestamp.html +func parseTimestamp(s string) (time.Time, bool) { + // TODO write code to check all the formats supported by + // http://yaml.org/type/timestamp.html instead of using time.Parse. + + // Quick check: all date formats start with YYYY-. + i := 0 + for ; i < len(s); i++ { + if c := s[i]; c < '0' || c > '9' { + break + } + } + if i != 4 || i == len(s) || s[i] != '-' { + return time.Time{}, false + } + for _, format := range allowedTimestampFormats { + if t, err := time.Parse(format, s); err == nil { + return t, true + } + } + return time.Time{}, false +} diff --git a/vendor/gopkg.in/yaml.v3/scannerc.go b/vendor/gopkg.in/yaml.v3/scannerc.go new file mode 100644 index 000000000..ca0070108 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/scannerc.go @@ -0,0 +1,3038 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// Copyright (c) 2006-2010 Kirill Simonov +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package yaml + +import ( + "bytes" + "fmt" +) + +// Introduction +// ************ +// +// The following notes assume that you are familiar with the YAML specification +// (http://yaml.org/spec/1.2/spec.html). We mostly follow it, although in +// some cases we are less restrictive that it requires. +// +// The process of transforming a YAML stream into a sequence of events is +// divided on two steps: Scanning and Parsing. +// +// The Scanner transforms the input stream into a sequence of tokens, while the +// parser transform the sequence of tokens produced by the Scanner into a +// sequence of parsing events. +// +// The Scanner is rather clever and complicated. The Parser, on the contrary, +// is a straightforward implementation of a recursive-descendant parser (or, +// LL(1) parser, as it is usually called). +// +// Actually there are two issues of Scanning that might be called "clever", the +// rest is quite straightforward. The issues are "block collection start" and +// "simple keys". Both issues are explained below in details. +// +// Here the Scanning step is explained and implemented. We start with the list +// of all the tokens produced by the Scanner together with short descriptions. +// +// Now, tokens: +// +// STREAM-START(encoding) # The stream start. +// STREAM-END # The stream end. +// VERSION-DIRECTIVE(major,minor) # The '%YAML' directive. +// TAG-DIRECTIVE(handle,prefix) # The '%TAG' directive. +// DOCUMENT-START # '---' +// DOCUMENT-END # '...' +// BLOCK-SEQUENCE-START # Indentation increase denoting a block +// BLOCK-MAPPING-START # sequence or a block mapping. +// BLOCK-END # Indentation decrease. 
// FLOW-SEQUENCE-START            # '['
// FLOW-SEQUENCE-END              # ']'
// FLOW-MAPPING-START             # '{'
// FLOW-MAPPING-END               # '}'
// BLOCK-ENTRY                    # '-'
// FLOW-ENTRY                     # ','
// KEY                            # '?' or nothing (simple keys).
// VALUE                          # ':'
// ALIAS(anchor)                  # '*anchor'
// ANCHOR(anchor)                 # '&anchor'
// TAG(handle,suffix)             # '!handle!suffix'
// SCALAR(value,style)            # A scalar.
//
// The following two tokens are "virtual" tokens denoting the beginning and the
// end of the stream:
//
//      STREAM-START(encoding)
//      STREAM-END
//
// We pass the information about the input stream encoding with the
// STREAM-START token.
//
// The next two tokens are responsible for tags:
//
//      VERSION-DIRECTIVE(major,minor)
//      TAG-DIRECTIVE(handle,prefix)
//
// Example:
//
//      %YAML   1.1
//      %TAG    !      !foo
//      %TAG    !yaml! tag:yaml.org,2002:
//      ---
//
// The corresponding sequence of tokens:
//
//      STREAM-START(utf-8)
//      VERSION-DIRECTIVE(1,1)
//      TAG-DIRECTIVE("!","!foo")
//      TAG-DIRECTIVE("!yaml","tag:yaml.org,2002:")
//      DOCUMENT-START
//      STREAM-END
//
// Note that the VERSION-DIRECTIVE and TAG-DIRECTIVE tokens occupy a whole
// line.
//
// The document start and end indicators are represented by:
//
//      DOCUMENT-START
//      DOCUMENT-END
//
// Note that if a YAML stream contains an implicit document (without '---'
// and '...' indicators), no DOCUMENT-START and DOCUMENT-END tokens will be
// produced.
//
// In the following examples, we present whole documents together with the
// produced tokens.
//
// 1. An implicit document:
//
//      'a scalar'
//
// Tokens:
//
//      STREAM-START(utf-8)
//      SCALAR("a scalar",single-quoted)
//      STREAM-END
//
// 2. An explicit document:
//
//      ---
//      'a scalar'
//      ...
//
// Tokens:
//
//      STREAM-START(utf-8)
//      DOCUMENT-START
//      SCALAR("a scalar",single-quoted)
//      DOCUMENT-END
//      STREAM-END
//
// 3.
Several documents in a stream: +// +// 'a scalar' +// --- +// 'another scalar' +// --- +// 'yet another scalar' +// +// Tokens: +// +// STREAM-START(utf-8) +// SCALAR("a scalar",single-quoted) +// DOCUMENT-START +// SCALAR("another scalar",single-quoted) +// DOCUMENT-START +// SCALAR("yet another scalar",single-quoted) +// STREAM-END +// +// We have already introduced the SCALAR token above. The following tokens are +// used to describe aliases, anchors, tag, and scalars: +// +// ALIAS(anchor) +// ANCHOR(anchor) +// TAG(handle,suffix) +// SCALAR(value,style) +// +// The following series of examples illustrate the usage of these tokens: +// +// 1. A recursive sequence: +// +// &A [ *A ] +// +// Tokens: +// +// STREAM-START(utf-8) +// ANCHOR("A") +// FLOW-SEQUENCE-START +// ALIAS("A") +// FLOW-SEQUENCE-END +// STREAM-END +// +// 2. A tagged scalar: +// +// !!float "3.14" # A good approximation. +// +// Tokens: +// +// STREAM-START(utf-8) +// TAG("!!","float") +// SCALAR("3.14",double-quoted) +// STREAM-END +// +// 3. Various scalar styles: +// +// --- # Implicit empty plain scalars do not produce tokens. +// --- a plain scalar +// --- 'a single-quoted scalar' +// --- "a double-quoted scalar" +// --- |- +// a literal scalar +// --- >- +// a folded +// scalar +// +// Tokens: +// +// STREAM-START(utf-8) +// DOCUMENT-START +// DOCUMENT-START +// SCALAR("a plain scalar",plain) +// DOCUMENT-START +// SCALAR("a single-quoted scalar",single-quoted) +// DOCUMENT-START +// SCALAR("a double-quoted scalar",double-quoted) +// DOCUMENT-START +// SCALAR("a literal scalar",literal) +// DOCUMENT-START +// SCALAR("a folded scalar",folded) +// STREAM-END +// +// Now it's time to review collection-related tokens. 
We will start with +// flow collections: +// +// FLOW-SEQUENCE-START +// FLOW-SEQUENCE-END +// FLOW-MAPPING-START +// FLOW-MAPPING-END +// FLOW-ENTRY +// KEY +// VALUE +// +// The tokens FLOW-SEQUENCE-START, FLOW-SEQUENCE-END, FLOW-MAPPING-START, and +// FLOW-MAPPING-END represent the indicators '[', ']', '{', and '}' +// correspondingly. FLOW-ENTRY represent the ',' indicator. Finally the +// indicators '?' and ':', which are used for denoting mapping keys and values, +// are represented by the KEY and VALUE tokens. +// +// The following examples show flow collections: +// +// 1. A flow sequence: +// +// [item 1, item 2, item 3] +// +// Tokens: +// +// STREAM-START(utf-8) +// FLOW-SEQUENCE-START +// SCALAR("item 1",plain) +// FLOW-ENTRY +// SCALAR("item 2",plain) +// FLOW-ENTRY +// SCALAR("item 3",plain) +// FLOW-SEQUENCE-END +// STREAM-END +// +// 2. A flow mapping: +// +// { +// a simple key: a value, # Note that the KEY token is produced. +// ? a complex key: another value, +// } +// +// Tokens: +// +// STREAM-START(utf-8) +// FLOW-MAPPING-START +// KEY +// SCALAR("a simple key",plain) +// VALUE +// SCALAR("a value",plain) +// FLOW-ENTRY +// KEY +// SCALAR("a complex key",plain) +// VALUE +// SCALAR("another value",plain) +// FLOW-ENTRY +// FLOW-MAPPING-END +// STREAM-END +// +// A simple key is a key which is not denoted by the '?' indicator. Note that +// the Scanner still produce the KEY token whenever it encounters a simple key. +// +// For scanning block collections, the following tokens are used (note that we +// repeat KEY and VALUE here): +// +// BLOCK-SEQUENCE-START +// BLOCK-MAPPING-START +// BLOCK-END +// BLOCK-ENTRY +// KEY +// VALUE +// +// The tokens BLOCK-SEQUENCE-START and BLOCK-MAPPING-START denote indentation +// increase that precedes a block collection (cf. the INDENT token in Python). +// The token BLOCK-END denote indentation decrease that ends a block collection +// (cf. the DEDENT token in Python). 
However YAML has some syntax pecularities +// that makes detections of these tokens more complex. +// +// The tokens BLOCK-ENTRY, KEY, and VALUE are used to represent the indicators +// '-', '?', and ':' correspondingly. +// +// The following examples show how the tokens BLOCK-SEQUENCE-START, +// BLOCK-MAPPING-START, and BLOCK-END are emitted by the Scanner: +// +// 1. Block sequences: +// +// - item 1 +// - item 2 +// - +// - item 3.1 +// - item 3.2 +// - +// key 1: value 1 +// key 2: value 2 +// +// Tokens: +// +// STREAM-START(utf-8) +// BLOCK-SEQUENCE-START +// BLOCK-ENTRY +// SCALAR("item 1",plain) +// BLOCK-ENTRY +// SCALAR("item 2",plain) +// BLOCK-ENTRY +// BLOCK-SEQUENCE-START +// BLOCK-ENTRY +// SCALAR("item 3.1",plain) +// BLOCK-ENTRY +// SCALAR("item 3.2",plain) +// BLOCK-END +// BLOCK-ENTRY +// BLOCK-MAPPING-START +// KEY +// SCALAR("key 1",plain) +// VALUE +// SCALAR("value 1",plain) +// KEY +// SCALAR("key 2",plain) +// VALUE +// SCALAR("value 2",plain) +// BLOCK-END +// BLOCK-END +// STREAM-END +// +// 2. Block mappings: +// +// a simple key: a value # The KEY token is produced here. +// ? a complex key +// : another value +// a mapping: +// key 1: value 1 +// key 2: value 2 +// a sequence: +// - item 1 +// - item 2 +// +// Tokens: +// +// STREAM-START(utf-8) +// BLOCK-MAPPING-START +// KEY +// SCALAR("a simple key",plain) +// VALUE +// SCALAR("a value",plain) +// KEY +// SCALAR("a complex key",plain) +// VALUE +// SCALAR("another value",plain) +// KEY +// SCALAR("a mapping",plain) +// BLOCK-MAPPING-START +// KEY +// SCALAR("key 1",plain) +// VALUE +// SCALAR("value 1",plain) +// KEY +// SCALAR("key 2",plain) +// VALUE +// SCALAR("value 2",plain) +// BLOCK-END +// KEY +// SCALAR("a sequence",plain) +// VALUE +// BLOCK-SEQUENCE-START +// BLOCK-ENTRY +// SCALAR("item 1",plain) +// BLOCK-ENTRY +// SCALAR("item 2",plain) +// BLOCK-END +// BLOCK-END +// STREAM-END +// +// YAML does not always require to start a new block collection from a new +// line. 
If the current line contains only '-', '?', and ':' indicators, a new +// block collection may start at the current line. The following examples +// illustrate this case: +// +// 1. Collections in a sequence: +// +// - - item 1 +// - item 2 +// - key 1: value 1 +// key 2: value 2 +// - ? complex key +// : complex value +// +// Tokens: +// +// STREAM-START(utf-8) +// BLOCK-SEQUENCE-START +// BLOCK-ENTRY +// BLOCK-SEQUENCE-START +// BLOCK-ENTRY +// SCALAR("item 1",plain) +// BLOCK-ENTRY +// SCALAR("item 2",plain) +// BLOCK-END +// BLOCK-ENTRY +// BLOCK-MAPPING-START +// KEY +// SCALAR("key 1",plain) +// VALUE +// SCALAR("value 1",plain) +// KEY +// SCALAR("key 2",plain) +// VALUE +// SCALAR("value 2",plain) +// BLOCK-END +// BLOCK-ENTRY +// BLOCK-MAPPING-START +// KEY +// SCALAR("complex key") +// VALUE +// SCALAR("complex value") +// BLOCK-END +// BLOCK-END +// STREAM-END +// +// 2. Collections in a mapping: +// +// ? a sequence +// : - item 1 +// - item 2 +// ? a mapping +// : key 1: value 1 +// key 2: value 2 +// +// Tokens: +// +// STREAM-START(utf-8) +// BLOCK-MAPPING-START +// KEY +// SCALAR("a sequence",plain) +// VALUE +// BLOCK-SEQUENCE-START +// BLOCK-ENTRY +// SCALAR("item 1",plain) +// BLOCK-ENTRY +// SCALAR("item 2",plain) +// BLOCK-END +// KEY +// SCALAR("a mapping",plain) +// VALUE +// BLOCK-MAPPING-START +// KEY +// SCALAR("key 1",plain) +// VALUE +// SCALAR("value 1",plain) +// KEY +// SCALAR("key 2",plain) +// VALUE +// SCALAR("value 2",plain) +// BLOCK-END +// BLOCK-END +// STREAM-END +// +// YAML also permits non-indented sequences if they are included into a block +// mapping. In this case, the token BLOCK-SEQUENCE-START is not produced: +// +// key: +// - item 1 # BLOCK-SEQUENCE-START is NOT produced here. 
+// - item 2 +// +// Tokens: +// +// STREAM-START(utf-8) +// BLOCK-MAPPING-START +// KEY +// SCALAR("key",plain) +// VALUE +// BLOCK-ENTRY +// SCALAR("item 1",plain) +// BLOCK-ENTRY +// SCALAR("item 2",plain) +// BLOCK-END +// + +// Ensure that the buffer contains the required number of characters. +// Return true on success, false on failure (reader error or memory error). +func cache(parser *yaml_parser_t, length int) bool { + // [Go] This was inlined: !cache(A, B) -> unread < B && !update(A, B) + return parser.unread >= length || yaml_parser_update_buffer(parser, length) +} + +// Advance the buffer pointer. +func skip(parser *yaml_parser_t) { + if !is_blank(parser.buffer, parser.buffer_pos) { + parser.newlines = 0 + } + parser.mark.index++ + parser.mark.column++ + parser.unread-- + parser.buffer_pos += width(parser.buffer[parser.buffer_pos]) +} + +func skip_line(parser *yaml_parser_t) { + if is_crlf(parser.buffer, parser.buffer_pos) { + parser.mark.index += 2 + parser.mark.column = 0 + parser.mark.line++ + parser.unread -= 2 + parser.buffer_pos += 2 + parser.newlines++ + } else if is_break(parser.buffer, parser.buffer_pos) { + parser.mark.index++ + parser.mark.column = 0 + parser.mark.line++ + parser.unread-- + parser.buffer_pos += width(parser.buffer[parser.buffer_pos]) + parser.newlines++ + } +} + +// Copy a character to a string buffer and advance pointers. +func read(parser *yaml_parser_t, s []byte) []byte { + if !is_blank(parser.buffer, parser.buffer_pos) { + parser.newlines = 0 + } + w := width(parser.buffer[parser.buffer_pos]) + if w == 0 { + panic("invalid character sequence") + } + if len(s) == 0 { + s = make([]byte, 0, 32) + } + if w == 1 && len(s)+w <= cap(s) { + s = s[:len(s)+1] + s[len(s)-1] = parser.buffer[parser.buffer_pos] + parser.buffer_pos++ + } else { + s = append(s, parser.buffer[parser.buffer_pos:parser.buffer_pos+w]...) 
+ parser.buffer_pos += w + } + parser.mark.index++ + parser.mark.column++ + parser.unread-- + return s +} + +// Copy a line break character to a string buffer and advance pointers. +func read_line(parser *yaml_parser_t, s []byte) []byte { + buf := parser.buffer + pos := parser.buffer_pos + switch { + case buf[pos] == '\r' && buf[pos+1] == '\n': + // CR LF . LF + s = append(s, '\n') + parser.buffer_pos += 2 + parser.mark.index++ + parser.unread-- + case buf[pos] == '\r' || buf[pos] == '\n': + // CR|LF . LF + s = append(s, '\n') + parser.buffer_pos += 1 + case buf[pos] == '\xC2' && buf[pos+1] == '\x85': + // NEL . LF + s = append(s, '\n') + parser.buffer_pos += 2 + case buf[pos] == '\xE2' && buf[pos+1] == '\x80' && (buf[pos+2] == '\xA8' || buf[pos+2] == '\xA9'): + // LS|PS . LS|PS + s = append(s, buf[parser.buffer_pos:pos+3]...) + parser.buffer_pos += 3 + default: + return s + } + parser.mark.index++ + parser.mark.column = 0 + parser.mark.line++ + parser.unread-- + parser.newlines++ + return s +} + +// Get the next token. +func yaml_parser_scan(parser *yaml_parser_t, token *yaml_token_t) bool { + // Erase the token object. + *token = yaml_token_t{} // [Go] Is this necessary? + + // No tokens after STREAM-END or error. + if parser.stream_end_produced || parser.error != yaml_NO_ERROR { + return true + } + + // Ensure that the tokens queue contains enough tokens. + if !parser.token_available { + if !yaml_parser_fetch_more_tokens(parser) { + return false + } + } + + // Fetch the next token from the queue. + *token = parser.tokens[parser.tokens_head] + parser.tokens_head++ + parser.tokens_parsed++ + parser.token_available = false + + if token.typ == yaml_STREAM_END_TOKEN { + parser.stream_end_produced = true + } + return true +} + +// Set the scanner error and return false. 
// yaml_parser_set_scanner_error records a scanner-level error on the parser
// (context, context mark, problem, and the current mark as the problem mark)
// and returns false so callers can propagate the failure directly.
func yaml_parser_set_scanner_error(parser *yaml_parser_t, context string, context_mark yaml_mark_t, problem string) bool {
	parser.error = yaml_SCANNER_ERROR
	parser.context = context
	parser.context_mark = context_mark
	parser.problem = problem
	parser.problem_mark = parser.mark
	return false
}

// yaml_parser_set_scanner_tag_error is a convenience wrapper that reports a
// tag-scanning error, choosing the context string based on whether the tag
// appeared inside a %TAG directive.
func yaml_parser_set_scanner_tag_error(parser *yaml_parser_t, directive bool, context_mark yaml_mark_t, problem string) bool {
	context := "while parsing a tag"
	if directive {
		context = "while parsing a %TAG directive"
	}
	return yaml_parser_set_scanner_error(parser, context, context_mark, problem)
}

// trace is a debugging helper: it prints "+++" with args immediately and
// returns a closure that prints "---" with the same args, intended to be
// deferred at the top of a traced function.
func trace(args ...interface{}) func() {
	pargs := append([]interface{}{"+++"}, args...)
	fmt.Println(pargs...)
	pargs = append([]interface{}{"---"}, args...)
	return func() { fmt.Println(pargs...) }
}

// Ensure that the tokens queue contains at least one token which can be
// returned to the Parser.
func yaml_parser_fetch_more_tokens(parser *yaml_parser_t) bool {
	// While we need more tokens to fetch, do it.
	for {
		// [Go] The comment parsing logic requires a lookahead of two tokens
		// so that foot comments may be parsed in time of associating them
		// with the tokens that are parsed before them, and also for line
		// comments to be transformed into head comments in some edge cases.
		if parser.tokens_head < len(parser.tokens)-2 {
			// If a potential simple key is at the head position, we need to fetch
			// the next token to disambiguate it.
			head_tok_idx, ok := parser.simple_keys_by_tok[parser.tokens_parsed]
			if !ok {
				// No pending simple key at the head; the queued tokens are
				// sufficient for the parser.
				break
			} else if valid, ok := yaml_simple_key_is_valid(parser, &parser.simple_keys[head_tok_idx]); !ok {
				// Validation itself failed (scanner error already recorded).
				return false
			} else if !valid {
				// The simple key expired; no further lookahead is required.
				break
			}
		}
		// Fetch the next token.
		if !yaml_parser_fetch_next_token(parser) {
			return false
		}
	}

	parser.token_available = true
	return true
}

// The dispatcher for token fetchers.
+func yaml_parser_fetch_next_token(parser *yaml_parser_t) (ok bool) { + // Ensure that the buffer is initialized. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + // Check if we just started scanning. Fetch STREAM-START then. + if !parser.stream_start_produced { + return yaml_parser_fetch_stream_start(parser) + } + + scan_mark := parser.mark + + // Eat whitespaces and comments until we reach the next token. + if !yaml_parser_scan_to_next_token(parser) { + return false + } + + // [Go] While unrolling indents, transform the head comments of prior + // indentation levels observed after scan_start into foot comments at + // the respective indexes. + + // Check the indentation level against the current column. + if !yaml_parser_unroll_indent(parser, parser.mark.column, scan_mark) { + return false + } + + // Ensure that the buffer contains at least 4 characters. 4 is the length + // of the longest indicators ('--- ' and '... '). + if parser.unread < 4 && !yaml_parser_update_buffer(parser, 4) { + return false + } + + // Is it the end of the stream? + if is_z(parser.buffer, parser.buffer_pos) { + return yaml_parser_fetch_stream_end(parser) + } + + // Is it a directive? + if parser.mark.column == 0 && parser.buffer[parser.buffer_pos] == '%' { + return yaml_parser_fetch_directive(parser) + } + + buf := parser.buffer + pos := parser.buffer_pos + + // Is it the document start indicator? + if parser.mark.column == 0 && buf[pos] == '-' && buf[pos+1] == '-' && buf[pos+2] == '-' && is_blankz(buf, pos+3) { + return yaml_parser_fetch_document_indicator(parser, yaml_DOCUMENT_START_TOKEN) + } + + // Is it the document end indicator? + if parser.mark.column == 0 && buf[pos] == '.' && buf[pos+1] == '.' && buf[pos+2] == '.' 
&& is_blankz(buf, pos+3) { + return yaml_parser_fetch_document_indicator(parser, yaml_DOCUMENT_END_TOKEN) + } + + comment_mark := parser.mark + if len(parser.tokens) > 0 && (parser.flow_level == 0 && buf[pos] == ':' || parser.flow_level > 0 && buf[pos] == ',') { + // Associate any following comments with the prior token. + comment_mark = parser.tokens[len(parser.tokens)-1].start_mark + } + defer func() { + if !ok { + return + } + if len(parser.tokens) > 0 && parser.tokens[len(parser.tokens)-1].typ == yaml_BLOCK_ENTRY_TOKEN { + // Sequence indicators alone have no line comments. It becomes + // a head comment for whatever follows. + return + } + if !yaml_parser_scan_line_comment(parser, comment_mark) { + ok = false + return + } + }() + + // Is it the flow sequence start indicator? + if buf[pos] == '[' { + return yaml_parser_fetch_flow_collection_start(parser, yaml_FLOW_SEQUENCE_START_TOKEN) + } + + // Is it the flow mapping start indicator? + if parser.buffer[parser.buffer_pos] == '{' { + return yaml_parser_fetch_flow_collection_start(parser, yaml_FLOW_MAPPING_START_TOKEN) + } + + // Is it the flow sequence end indicator? + if parser.buffer[parser.buffer_pos] == ']' { + return yaml_parser_fetch_flow_collection_end(parser, + yaml_FLOW_SEQUENCE_END_TOKEN) + } + + // Is it the flow mapping end indicator? + if parser.buffer[parser.buffer_pos] == '}' { + return yaml_parser_fetch_flow_collection_end(parser, + yaml_FLOW_MAPPING_END_TOKEN) + } + + // Is it the flow entry indicator? + if parser.buffer[parser.buffer_pos] == ',' { + return yaml_parser_fetch_flow_entry(parser) + } + + // Is it the block entry indicator? + if parser.buffer[parser.buffer_pos] == '-' && is_blankz(parser.buffer, parser.buffer_pos+1) { + return yaml_parser_fetch_block_entry(parser) + } + + // Is it the key indicator? + if parser.buffer[parser.buffer_pos] == '?' 
&& (parser.flow_level > 0 || is_blankz(parser.buffer, parser.buffer_pos+1)) { + return yaml_parser_fetch_key(parser) + } + + // Is it the value indicator? + if parser.buffer[parser.buffer_pos] == ':' && (parser.flow_level > 0 || is_blankz(parser.buffer, parser.buffer_pos+1)) { + return yaml_parser_fetch_value(parser) + } + + // Is it an alias? + if parser.buffer[parser.buffer_pos] == '*' { + return yaml_parser_fetch_anchor(parser, yaml_ALIAS_TOKEN) + } + + // Is it an anchor? + if parser.buffer[parser.buffer_pos] == '&' { + return yaml_parser_fetch_anchor(parser, yaml_ANCHOR_TOKEN) + } + + // Is it a tag? + if parser.buffer[parser.buffer_pos] == '!' { + return yaml_parser_fetch_tag(parser) + } + + // Is it a literal scalar? + if parser.buffer[parser.buffer_pos] == '|' && parser.flow_level == 0 { + return yaml_parser_fetch_block_scalar(parser, true) + } + + // Is it a folded scalar? + if parser.buffer[parser.buffer_pos] == '>' && parser.flow_level == 0 { + return yaml_parser_fetch_block_scalar(parser, false) + } + + // Is it a single-quoted scalar? + if parser.buffer[parser.buffer_pos] == '\'' { + return yaml_parser_fetch_flow_scalar(parser, true) + } + + // Is it a double-quoted scalar? + if parser.buffer[parser.buffer_pos] == '"' { + return yaml_parser_fetch_flow_scalar(parser, false) + } + + // Is it a plain scalar? + // + // A plain scalar may start with any non-blank characters except + // + // '-', '?', ':', ',', '[', ']', '{', '}', + // '#', '&', '*', '!', '|', '>', '\'', '\"', + // '%', '@', '`'. + // + // In the block context (and, for the '-' indicator, in the flow context + // too), it may also start with the characters + // + // '-', '?', ':' + // + // if it is followed by a non-space character. + // + // The last rule is more restrictive than the specification requires. + // [Go] TODO Make this logic more reasonable. 
+ //switch parser.buffer[parser.buffer_pos] { + //case '-', '?', ':', ',', '?', '-', ',', ':', ']', '[', '}', '{', '&', '#', '!', '*', '>', '|', '"', '\'', '@', '%', '-', '`': + //} + if !(is_blankz(parser.buffer, parser.buffer_pos) || parser.buffer[parser.buffer_pos] == '-' || + parser.buffer[parser.buffer_pos] == '?' || parser.buffer[parser.buffer_pos] == ':' || + parser.buffer[parser.buffer_pos] == ',' || parser.buffer[parser.buffer_pos] == '[' || + parser.buffer[parser.buffer_pos] == ']' || parser.buffer[parser.buffer_pos] == '{' || + parser.buffer[parser.buffer_pos] == '}' || parser.buffer[parser.buffer_pos] == '#' || + parser.buffer[parser.buffer_pos] == '&' || parser.buffer[parser.buffer_pos] == '*' || + parser.buffer[parser.buffer_pos] == '!' || parser.buffer[parser.buffer_pos] == '|' || + parser.buffer[parser.buffer_pos] == '>' || parser.buffer[parser.buffer_pos] == '\'' || + parser.buffer[parser.buffer_pos] == '"' || parser.buffer[parser.buffer_pos] == '%' || + parser.buffer[parser.buffer_pos] == '@' || parser.buffer[parser.buffer_pos] == '`') || + (parser.buffer[parser.buffer_pos] == '-' && !is_blank(parser.buffer, parser.buffer_pos+1)) || + (parser.flow_level == 0 && + (parser.buffer[parser.buffer_pos] == '?' || parser.buffer[parser.buffer_pos] == ':') && + !is_blankz(parser.buffer, parser.buffer_pos+1)) { + return yaml_parser_fetch_plain_scalar(parser) + } + + // If we don't determine the token type so far, it is an error. + return yaml_parser_set_scanner_error(parser, + "while scanning for the next token", parser.mark, + "found character that cannot start any token") +} + +func yaml_simple_key_is_valid(parser *yaml_parser_t, simple_key *yaml_simple_key_t) (valid, ok bool) { + if !simple_key.possible { + return false, true + } + + // The 1.2 specification says: + // + // "If the ? indicator is omitted, parsing needs to see past the + // implicit key to recognize it as such. 
To limit the amount of + // lookahead required, the “:” indicator must appear at most 1024 + // Unicode characters beyond the start of the key. In addition, the key + // is restricted to a single line." + // + if simple_key.mark.line < parser.mark.line || simple_key.mark.index+1024 < parser.mark.index { + // Check if the potential simple key to be removed is required. + if simple_key.required { + return false, yaml_parser_set_scanner_error(parser, + "while scanning a simple key", simple_key.mark, + "could not find expected ':'") + } + simple_key.possible = false + return false, true + } + return true, true +} + +// Check if a simple key may start at the current position and add it if +// needed. +func yaml_parser_save_simple_key(parser *yaml_parser_t) bool { + // A simple key is required at the current position if the scanner is in + // the block context and the current column coincides with the indentation + // level. + + required := parser.flow_level == 0 && parser.indent == parser.mark.column + + // + // If the current position may start a simple key, save it. + // + if parser.simple_key_allowed { + simple_key := yaml_simple_key_t{ + possible: true, + required: required, + token_number: parser.tokens_parsed + (len(parser.tokens) - parser.tokens_head), + mark: parser.mark, + } + + if !yaml_parser_remove_simple_key(parser) { + return false + } + parser.simple_keys[len(parser.simple_keys)-1] = simple_key + parser.simple_keys_by_tok[simple_key.token_number] = len(parser.simple_keys) - 1 + } + return true +} + +// Remove a potential simple key at the current flow level. +func yaml_parser_remove_simple_key(parser *yaml_parser_t) bool { + i := len(parser.simple_keys) - 1 + if parser.simple_keys[i].possible { + // If the key is required, it is an error. + if parser.simple_keys[i].required { + return yaml_parser_set_scanner_error(parser, + "while scanning a simple key", parser.simple_keys[i].mark, + "could not find expected ':'") + } + // Remove the key from the stack. 
+ parser.simple_keys[i].possible = false + delete(parser.simple_keys_by_tok, parser.simple_keys[i].token_number) + } + return true +} + +// max_flow_level limits the flow_level +const max_flow_level = 10000 + +// Increase the flow level and resize the simple key list if needed. +func yaml_parser_increase_flow_level(parser *yaml_parser_t) bool { + // Reset the simple key on the next level. + parser.simple_keys = append(parser.simple_keys, yaml_simple_key_t{ + possible: false, + required: false, + token_number: parser.tokens_parsed + (len(parser.tokens) - parser.tokens_head), + mark: parser.mark, + }) + + // Increase the flow level. + parser.flow_level++ + if parser.flow_level > max_flow_level { + return yaml_parser_set_scanner_error(parser, + "while increasing flow level", parser.simple_keys[len(parser.simple_keys)-1].mark, + fmt.Sprintf("exceeded max depth of %d", max_flow_level)) + } + return true +} + +// Decrease the flow level. +func yaml_parser_decrease_flow_level(parser *yaml_parser_t) bool { + if parser.flow_level > 0 { + parser.flow_level-- + last := len(parser.simple_keys) - 1 + delete(parser.simple_keys_by_tok, parser.simple_keys[last].token_number) + parser.simple_keys = parser.simple_keys[:last] + } + return true +} + +// max_indents limits the indents stack size +const max_indents = 10000 + +// Push the current indentation level to the stack and set the new level +// the current column is greater than the indentation level. In this case, +// append or insert the specified token into the token queue. +func yaml_parser_roll_indent(parser *yaml_parser_t, column, number int, typ yaml_token_type_t, mark yaml_mark_t) bool { + // In the flow context, do nothing. + if parser.flow_level > 0 { + return true + } + + if parser.indent < column { + // Push the current indentation level to the stack and set the new + // indentation level. 
+ parser.indents = append(parser.indents, parser.indent) + parser.indent = column + if len(parser.indents) > max_indents { + return yaml_parser_set_scanner_error(parser, + "while increasing indent level", parser.simple_keys[len(parser.simple_keys)-1].mark, + fmt.Sprintf("exceeded max depth of %d", max_indents)) + } + + // Create a token and insert it into the queue. + token := yaml_token_t{ + typ: typ, + start_mark: mark, + end_mark: mark, + } + if number > -1 { + number -= parser.tokens_parsed + } + yaml_insert_token(parser, number, &token) + } + return true +} + +// Pop indentation levels from the indents stack until the current level +// becomes less or equal to the column. For each indentation level, append +// the BLOCK-END token. +func yaml_parser_unroll_indent(parser *yaml_parser_t, column int, scan_mark yaml_mark_t) bool { + // In the flow context, do nothing. + if parser.flow_level > 0 { + return true + } + + block_mark := scan_mark + block_mark.index-- + + // Loop through the indentation levels in the stack. + for parser.indent > column { + + // [Go] Reposition the end token before potential following + // foot comments of parent blocks. For that, search + // backwards for recent comments that were at the same + // indent as the block that is ending now. + stop_index := block_mark.index + for i := len(parser.comments) - 1; i >= 0; i-- { + comment := &parser.comments[i] + + if comment.end_mark.index < stop_index { + // Don't go back beyond the start of the comment/whitespace scan, unless column < 0. + // If requested indent column is < 0, then the document is over and everything else + // is a foot anyway. + break + } + if comment.start_mark.column == parser.indent+1 { + // This is a good match. But maybe there's a former comment + // at that same indent level, so keep searching. 
+ block_mark = comment.start_mark + } + + // While the end of the former comment matches with + // the start of the following one, we know there's + // nothing in between and scanning is still safe. + stop_index = comment.scan_mark.index + } + + // Create a token and append it to the queue. + token := yaml_token_t{ + typ: yaml_BLOCK_END_TOKEN, + start_mark: block_mark, + end_mark: block_mark, + } + yaml_insert_token(parser, -1, &token) + + // Pop the indentation level. + parser.indent = parser.indents[len(parser.indents)-1] + parser.indents = parser.indents[:len(parser.indents)-1] + } + return true +} + +// Initialize the scanner and produce the STREAM-START token. +func yaml_parser_fetch_stream_start(parser *yaml_parser_t) bool { + + // Set the initial indentation. + parser.indent = -1 + + // Initialize the simple key stack. + parser.simple_keys = append(parser.simple_keys, yaml_simple_key_t{}) + + parser.simple_keys_by_tok = make(map[int]int) + + // A simple key is allowed at the beginning of the stream. + parser.simple_key_allowed = true + + // We have started. + parser.stream_start_produced = true + + // Create the STREAM-START token and append it to the queue. + token := yaml_token_t{ + typ: yaml_STREAM_START_TOKEN, + start_mark: parser.mark, + end_mark: parser.mark, + encoding: parser.encoding, + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the STREAM-END token and shut down the scanner. +func yaml_parser_fetch_stream_end(parser *yaml_parser_t) bool { + + // Force new line. + if parser.mark.column != 0 { + parser.mark.column = 0 + parser.mark.line++ + } + + // Reset the indentation level. + if !yaml_parser_unroll_indent(parser, -1, parser.mark) { + return false + } + + // Reset simple keys. + if !yaml_parser_remove_simple_key(parser) { + return false + } + + parser.simple_key_allowed = false + + // Create the STREAM-END token and append it to the queue. 
+ token := yaml_token_t{ + typ: yaml_STREAM_END_TOKEN, + start_mark: parser.mark, + end_mark: parser.mark, + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce a VERSION-DIRECTIVE or TAG-DIRECTIVE token. +func yaml_parser_fetch_directive(parser *yaml_parser_t) bool { + // Reset the indentation level. + if !yaml_parser_unroll_indent(parser, -1, parser.mark) { + return false + } + + // Reset simple keys. + if !yaml_parser_remove_simple_key(parser) { + return false + } + + parser.simple_key_allowed = false + + // Create the YAML-DIRECTIVE or TAG-DIRECTIVE token. + token := yaml_token_t{} + if !yaml_parser_scan_directive(parser, &token) { + return false + } + // Append the token to the queue. + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the DOCUMENT-START or DOCUMENT-END token. +func yaml_parser_fetch_document_indicator(parser *yaml_parser_t, typ yaml_token_type_t) bool { + // Reset the indentation level. + if !yaml_parser_unroll_indent(parser, -1, parser.mark) { + return false + } + + // Reset simple keys. + if !yaml_parser_remove_simple_key(parser) { + return false + } + + parser.simple_key_allowed = false + + // Consume the token. + start_mark := parser.mark + + skip(parser) + skip(parser) + skip(parser) + + end_mark := parser.mark + + // Create the DOCUMENT-START or DOCUMENT-END token. + token := yaml_token_t{ + typ: typ, + start_mark: start_mark, + end_mark: end_mark, + } + // Append the token to the queue. + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the FLOW-SEQUENCE-START or FLOW-MAPPING-START token. +func yaml_parser_fetch_flow_collection_start(parser *yaml_parser_t, typ yaml_token_type_t) bool { + + // The indicators '[' and '{' may start a simple key. + if !yaml_parser_save_simple_key(parser) { + return false + } + + // Increase the flow level. + if !yaml_parser_increase_flow_level(parser) { + return false + } + + // A simple key may follow the indicators '[' and '{'. 
+ parser.simple_key_allowed = true + + // Consume the token. + start_mark := parser.mark + skip(parser) + end_mark := parser.mark + + // Create the FLOW-SEQUENCE-START of FLOW-MAPPING-START token. + token := yaml_token_t{ + typ: typ, + start_mark: start_mark, + end_mark: end_mark, + } + // Append the token to the queue. + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the FLOW-SEQUENCE-END or FLOW-MAPPING-END token. +func yaml_parser_fetch_flow_collection_end(parser *yaml_parser_t, typ yaml_token_type_t) bool { + // Reset any potential simple key on the current flow level. + if !yaml_parser_remove_simple_key(parser) { + return false + } + + // Decrease the flow level. + if !yaml_parser_decrease_flow_level(parser) { + return false + } + + // No simple keys after the indicators ']' and '}'. + parser.simple_key_allowed = false + + // Consume the token. + + start_mark := parser.mark + skip(parser) + end_mark := parser.mark + + // Create the FLOW-SEQUENCE-END of FLOW-MAPPING-END token. + token := yaml_token_t{ + typ: typ, + start_mark: start_mark, + end_mark: end_mark, + } + // Append the token to the queue. + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the FLOW-ENTRY token. +func yaml_parser_fetch_flow_entry(parser *yaml_parser_t) bool { + // Reset any potential simple keys on the current flow level. + if !yaml_parser_remove_simple_key(parser) { + return false + } + + // Simple keys are allowed after ','. + parser.simple_key_allowed = true + + // Consume the token. + start_mark := parser.mark + skip(parser) + end_mark := parser.mark + + // Create the FLOW-ENTRY token and append it to the queue. + token := yaml_token_t{ + typ: yaml_FLOW_ENTRY_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the BLOCK-ENTRY token. +func yaml_parser_fetch_block_entry(parser *yaml_parser_t) bool { + // Check if the scanner is in the block context. 
+ if parser.flow_level == 0 { + // Check if we are allowed to start a new entry. + if !parser.simple_key_allowed { + return yaml_parser_set_scanner_error(parser, "", parser.mark, + "block sequence entries are not allowed in this context") + } + // Add the BLOCK-SEQUENCE-START token if needed. + if !yaml_parser_roll_indent(parser, parser.mark.column, -1, yaml_BLOCK_SEQUENCE_START_TOKEN, parser.mark) { + return false + } + } else { + // It is an error for the '-' indicator to occur in the flow context, + // but we let the Parser detect and report about it because the Parser + // is able to point to the context. + } + + // Reset any potential simple keys on the current flow level. + if !yaml_parser_remove_simple_key(parser) { + return false + } + + // Simple keys are allowed after '-'. + parser.simple_key_allowed = true + + // Consume the token. + start_mark := parser.mark + skip(parser) + end_mark := parser.mark + + // Create the BLOCK-ENTRY token and append it to the queue. + token := yaml_token_t{ + typ: yaml_BLOCK_ENTRY_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the KEY token. +func yaml_parser_fetch_key(parser *yaml_parser_t) bool { + + // In the block context, additional checks are required. + if parser.flow_level == 0 { + // Check if we are allowed to start a new key (not nessesary simple). + if !parser.simple_key_allowed { + return yaml_parser_set_scanner_error(parser, "", parser.mark, + "mapping keys are not allowed in this context") + } + // Add the BLOCK-MAPPING-START token if needed. + if !yaml_parser_roll_indent(parser, parser.mark.column, -1, yaml_BLOCK_MAPPING_START_TOKEN, parser.mark) { + return false + } + } + + // Reset any potential simple keys on the current flow level. + if !yaml_parser_remove_simple_key(parser) { + return false + } + + // Simple keys are allowed after '?' in the block context. 
+ parser.simple_key_allowed = parser.flow_level == 0 + + // Consume the token. + start_mark := parser.mark + skip(parser) + end_mark := parser.mark + + // Create the KEY token and append it to the queue. + token := yaml_token_t{ + typ: yaml_KEY_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the VALUE token. +func yaml_parser_fetch_value(parser *yaml_parser_t) bool { + + simple_key := &parser.simple_keys[len(parser.simple_keys)-1] + + // Have we found a simple key? + if valid, ok := yaml_simple_key_is_valid(parser, simple_key); !ok { + return false + + } else if valid { + + // Create the KEY token and insert it into the queue. + token := yaml_token_t{ + typ: yaml_KEY_TOKEN, + start_mark: simple_key.mark, + end_mark: simple_key.mark, + } + yaml_insert_token(parser, simple_key.token_number-parser.tokens_parsed, &token) + + // In the block context, we may need to add the BLOCK-MAPPING-START token. + if !yaml_parser_roll_indent(parser, simple_key.mark.column, + simple_key.token_number, + yaml_BLOCK_MAPPING_START_TOKEN, simple_key.mark) { + return false + } + + // Remove the simple key. + simple_key.possible = false + delete(parser.simple_keys_by_tok, simple_key.token_number) + + // A simple key cannot follow another simple key. + parser.simple_key_allowed = false + + } else { + // The ':' indicator follows a complex key. + + // In the block context, extra checks are required. + if parser.flow_level == 0 { + + // Check if we are allowed to start a complex value. + if !parser.simple_key_allowed { + return yaml_parser_set_scanner_error(parser, "", parser.mark, + "mapping values are not allowed in this context") + } + + // Add the BLOCK-MAPPING-START token if needed. + if !yaml_parser_roll_indent(parser, parser.mark.column, -1, yaml_BLOCK_MAPPING_START_TOKEN, parser.mark) { + return false + } + } + + // Simple keys after ':' are allowed in the block context. 
+ parser.simple_key_allowed = parser.flow_level == 0 + } + + // Consume the token. + start_mark := parser.mark + skip(parser) + end_mark := parser.mark + + // Create the VALUE token and append it to the queue. + token := yaml_token_t{ + typ: yaml_VALUE_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the ALIAS or ANCHOR token. +func yaml_parser_fetch_anchor(parser *yaml_parser_t, typ yaml_token_type_t) bool { + // An anchor or an alias could be a simple key. + if !yaml_parser_save_simple_key(parser) { + return false + } + + // A simple key cannot follow an anchor or an alias. + parser.simple_key_allowed = false + + // Create the ALIAS or ANCHOR token and append it to the queue. + var token yaml_token_t + if !yaml_parser_scan_anchor(parser, &token, typ) { + return false + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the TAG token. +func yaml_parser_fetch_tag(parser *yaml_parser_t) bool { + // A tag could be a simple key. + if !yaml_parser_save_simple_key(parser) { + return false + } + + // A simple key cannot follow a tag. + parser.simple_key_allowed = false + + // Create the TAG token and append it to the queue. + var token yaml_token_t + if !yaml_parser_scan_tag(parser, &token) { + return false + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the SCALAR(...,literal) or SCALAR(...,folded) tokens. +func yaml_parser_fetch_block_scalar(parser *yaml_parser_t, literal bool) bool { + // Remove any potential simple keys. + if !yaml_parser_remove_simple_key(parser) { + return false + } + + // A simple key may follow a block scalar. + parser.simple_key_allowed = true + + // Create the SCALAR token and append it to the queue. 
+ var token yaml_token_t + if !yaml_parser_scan_block_scalar(parser, &token, literal) { + return false + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the SCALAR(...,single-quoted) or SCALAR(...,double-quoted) tokens. +func yaml_parser_fetch_flow_scalar(parser *yaml_parser_t, single bool) bool { + // A plain scalar could be a simple key. + if !yaml_parser_save_simple_key(parser) { + return false + } + + // A simple key cannot follow a flow scalar. + parser.simple_key_allowed = false + + // Create the SCALAR token and append it to the queue. + var token yaml_token_t + if !yaml_parser_scan_flow_scalar(parser, &token, single) { + return false + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Produce the SCALAR(...,plain) token. +func yaml_parser_fetch_plain_scalar(parser *yaml_parser_t) bool { + // A plain scalar could be a simple key. + if !yaml_parser_save_simple_key(parser) { + return false + } + + // A simple key cannot follow a flow scalar. + parser.simple_key_allowed = false + + // Create the SCALAR token and append it to the queue. + var token yaml_token_t + if !yaml_parser_scan_plain_scalar(parser, &token) { + return false + } + yaml_insert_token(parser, -1, &token) + return true +} + +// Eat whitespaces and comments until the next token is found. +func yaml_parser_scan_to_next_token(parser *yaml_parser_t) bool { + + scan_mark := parser.mark + + // Until the next token is not found. + for { + // Allow the BOM mark to start a line. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if parser.mark.column == 0 && is_bom(parser.buffer, parser.buffer_pos) { + skip(parser) + } + + // Eat whitespaces. + // Tabs are allowed: + // - in the flow context + // - in the block context, but not at the beginning of the line or + // after '-', '?', or ':' (complex value). 
+ if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + for parser.buffer[parser.buffer_pos] == ' ' || ((parser.flow_level > 0 || !parser.simple_key_allowed) && parser.buffer[parser.buffer_pos] == '\t') { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Check if we just had a line comment under a sequence entry that + // looks more like a header to the following content. Similar to this: + // + // - # The comment + // - Some data + // + // If so, transform the line comment to a head comment and reposition. + if len(parser.comments) > 0 && len(parser.tokens) > 1 { + tokenA := parser.tokens[len(parser.tokens)-2] + tokenB := parser.tokens[len(parser.tokens)-1] + comment := &parser.comments[len(parser.comments)-1] + if tokenA.typ == yaml_BLOCK_SEQUENCE_START_TOKEN && tokenB.typ == yaml_BLOCK_ENTRY_TOKEN && len(comment.line) > 0 && !is_break(parser.buffer, parser.buffer_pos) { + // If it was in the prior line, reposition so it becomes a + // header of the follow up token. Otherwise, keep it in place + // so it becomes a header of the former. + comment.head = comment.line + comment.line = nil + if comment.start_mark.line == parser.mark.line-1 { + comment.token_mark = parser.mark + } + } + } + + // Eat a comment until a line break. + if parser.buffer[parser.buffer_pos] == '#' { + if !yaml_parser_scan_comments(parser, scan_mark) { + return false + } + } + + // If it is a line break, eat it. + if is_break(parser.buffer, parser.buffer_pos) { + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + skip_line(parser) + + // In the block context, a new line may start a simple key. + if parser.flow_level == 0 { + parser.simple_key_allowed = true + } + } else { + break // We have found a token. + } + } + + return true +} + +// Scan a YAML-DIRECTIVE or TAG-DIRECTIVE token. 
+// +// Scope: +// %YAML 1.1 # a comment \n +// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +// %TAG !yaml! tag:yaml.org,2002: \n +// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +// +func yaml_parser_scan_directive(parser *yaml_parser_t, token *yaml_token_t) bool { + // Eat '%'. + start_mark := parser.mark + skip(parser) + + // Scan the directive name. + var name []byte + if !yaml_parser_scan_directive_name(parser, start_mark, &name) { + return false + } + + // Is it a YAML directive? + if bytes.Equal(name, []byte("YAML")) { + // Scan the VERSION directive value. + var major, minor int8 + if !yaml_parser_scan_version_directive_value(parser, start_mark, &major, &minor) { + return false + } + end_mark := parser.mark + + // Create a VERSION-DIRECTIVE token. + *token = yaml_token_t{ + typ: yaml_VERSION_DIRECTIVE_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + major: major, + minor: minor, + } + + // Is it a TAG directive? + } else if bytes.Equal(name, []byte("TAG")) { + // Scan the TAG directive value. + var handle, prefix []byte + if !yaml_parser_scan_tag_directive_value(parser, start_mark, &handle, &prefix) { + return false + } + end_mark := parser.mark + + // Create a TAG-DIRECTIVE token. + *token = yaml_token_t{ + typ: yaml_TAG_DIRECTIVE_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + value: handle, + prefix: prefix, + } + + // Unknown directive. + } else { + yaml_parser_set_scanner_error(parser, "while scanning a directive", + start_mark, "found unknown directive name") + return false + } + + // Eat the rest of the line including any comments. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + for is_blank(parser.buffer, parser.buffer_pos) { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + if parser.buffer[parser.buffer_pos] == '#' { + // [Go] Discard this inline comment for the time being. 
+ //if !yaml_parser_scan_line_comment(parser, start_mark) { + // return false + //} + for !is_breakz(parser.buffer, parser.buffer_pos) { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + } + + // Check if we are at the end of the line. + if !is_breakz(parser.buffer, parser.buffer_pos) { + yaml_parser_set_scanner_error(parser, "while scanning a directive", + start_mark, "did not find expected comment or line break") + return false + } + + // Eat a line break. + if is_break(parser.buffer, parser.buffer_pos) { + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + skip_line(parser) + } + + return true +} + +// Scan the directive name. +// +// Scope: +// %YAML 1.1 # a comment \n +// ^^^^ +// %TAG !yaml! tag:yaml.org,2002: \n +// ^^^ +// +func yaml_parser_scan_directive_name(parser *yaml_parser_t, start_mark yaml_mark_t, name *[]byte) bool { + // Consume the directive name. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + var s []byte + for is_alpha(parser.buffer, parser.buffer_pos) { + s = read(parser, s) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Check if the name is empty. + if len(s) == 0 { + yaml_parser_set_scanner_error(parser, "while scanning a directive", + start_mark, "could not find expected directive name") + return false + } + + // Check for an blank character after the name. + if !is_blankz(parser.buffer, parser.buffer_pos) { + yaml_parser_set_scanner_error(parser, "while scanning a directive", + start_mark, "found unexpected non-alphabetical character") + return false + } + *name = s + return true +} + +// Scan the value of VERSION-DIRECTIVE. +// +// Scope: +// %YAML 1.1 # a comment \n +// ^^^^^^ +func yaml_parser_scan_version_directive_value(parser *yaml_parser_t, start_mark yaml_mark_t, major, minor *int8) bool { + // Eat whitespaces. 
+ if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + for is_blank(parser.buffer, parser.buffer_pos) { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Consume the major version number. + if !yaml_parser_scan_version_directive_number(parser, start_mark, major) { + return false + } + + // Eat '.'. + if parser.buffer[parser.buffer_pos] != '.' { + return yaml_parser_set_scanner_error(parser, "while scanning a %YAML directive", + start_mark, "did not find expected digit or '.' character") + } + + skip(parser) + + // Consume the minor version number. + if !yaml_parser_scan_version_directive_number(parser, start_mark, minor) { + return false + } + return true +} + +const max_number_length = 2 + +// Scan the version number of VERSION-DIRECTIVE. +// +// Scope: +// %YAML 1.1 # a comment \n +// ^ +// %YAML 1.1 # a comment \n +// ^ +func yaml_parser_scan_version_directive_number(parser *yaml_parser_t, start_mark yaml_mark_t, number *int8) bool { + + // Repeat while the next character is digit. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + var value, length int8 + for is_digit(parser.buffer, parser.buffer_pos) { + // Check if the number is too long. + length++ + if length > max_number_length { + return yaml_parser_set_scanner_error(parser, "while scanning a %YAML directive", + start_mark, "found extremely long version number") + } + value = value*10 + int8(as_digit(parser.buffer, parser.buffer_pos)) + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Check if the number was present. + if length == 0 { + return yaml_parser_set_scanner_error(parser, "while scanning a %YAML directive", + start_mark, "did not find expected version number") + } + *number = value + return true +} + +// Scan the value of a TAG-DIRECTIVE token. +// +// Scope: +// %TAG !yaml! 
tag:yaml.org,2002: \n +// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +// +func yaml_parser_scan_tag_directive_value(parser *yaml_parser_t, start_mark yaml_mark_t, handle, prefix *[]byte) bool { + var handle_value, prefix_value []byte + + // Eat whitespaces. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + for is_blank(parser.buffer, parser.buffer_pos) { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Scan a handle. + if !yaml_parser_scan_tag_handle(parser, true, start_mark, &handle_value) { + return false + } + + // Expect a whitespace. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if !is_blank(parser.buffer, parser.buffer_pos) { + yaml_parser_set_scanner_error(parser, "while scanning a %TAG directive", + start_mark, "did not find expected whitespace") + return false + } + + // Eat whitespaces. + for is_blank(parser.buffer, parser.buffer_pos) { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Scan a prefix. + if !yaml_parser_scan_tag_uri(parser, true, nil, start_mark, &prefix_value) { + return false + } + + // Expect a whitespace or line break. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if !is_blankz(parser.buffer, parser.buffer_pos) { + yaml_parser_set_scanner_error(parser, "while scanning a %TAG directive", + start_mark, "did not find expected whitespace or line break") + return false + } + + *handle = handle_value + *prefix = prefix_value + return true +} + +func yaml_parser_scan_anchor(parser *yaml_parser_t, token *yaml_token_t, typ yaml_token_type_t) bool { + var s []byte + + // Eat the indicator character. + start_mark := parser.mark + skip(parser) + + // Consume the value. 
+ if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + for is_alpha(parser.buffer, parser.buffer_pos) { + s = read(parser, s) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + end_mark := parser.mark + + /* + * Check if length of the anchor is greater than 0 and it is followed by + * a whitespace character or one of the indicators: + * + * '?', ':', ',', ']', '}', '%', '@', '`'. + */ + + if len(s) == 0 || + !(is_blankz(parser.buffer, parser.buffer_pos) || parser.buffer[parser.buffer_pos] == '?' || + parser.buffer[parser.buffer_pos] == ':' || parser.buffer[parser.buffer_pos] == ',' || + parser.buffer[parser.buffer_pos] == ']' || parser.buffer[parser.buffer_pos] == '}' || + parser.buffer[parser.buffer_pos] == '%' || parser.buffer[parser.buffer_pos] == '@' || + parser.buffer[parser.buffer_pos] == '`') { + context := "while scanning an alias" + if typ == yaml_ANCHOR_TOKEN { + context = "while scanning an anchor" + } + yaml_parser_set_scanner_error(parser, context, start_mark, + "did not find expected alphabetic or numeric character") + return false + } + + // Create a token. + *token = yaml_token_t{ + typ: typ, + start_mark: start_mark, + end_mark: end_mark, + value: s, + } + + return true +} + +/* + * Scan a TAG token. + */ + +func yaml_parser_scan_tag(parser *yaml_parser_t, token *yaml_token_t) bool { + var handle, suffix []byte + + start_mark := parser.mark + + // Check if the tag is in the canonical form. + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + + if parser.buffer[parser.buffer_pos+1] == '<' { + // Keep the handle as '' + + // Eat '!<' + skip(parser) + skip(parser) + + // Consume the tag value. + if !yaml_parser_scan_tag_uri(parser, false, nil, start_mark, &suffix) { + return false + } + + // Check for '>' and eat it. 
+ if parser.buffer[parser.buffer_pos] != '>' { + yaml_parser_set_scanner_error(parser, "while scanning a tag", + start_mark, "did not find the expected '>'") + return false + } + + skip(parser) + } else { + // The tag has either the '!suffix' or the '!handle!suffix' form. + + // First, try to scan a handle. + if !yaml_parser_scan_tag_handle(parser, false, start_mark, &handle) { + return false + } + + // Check if it is, indeed, handle. + if handle[0] == '!' && len(handle) > 1 && handle[len(handle)-1] == '!' { + // Scan the suffix now. + if !yaml_parser_scan_tag_uri(parser, false, nil, start_mark, &suffix) { + return false + } + } else { + // It wasn't a handle after all. Scan the rest of the tag. + if !yaml_parser_scan_tag_uri(parser, false, handle, start_mark, &suffix) { + return false + } + + // Set the handle to '!'. + handle = []byte{'!'} + + // A special case: the '!' tag. Set the handle to '' and the + // suffix to '!'. + if len(suffix) == 0 { + handle, suffix = suffix, handle + } + } + } + + // Check the character which ends the tag. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if !is_blankz(parser.buffer, parser.buffer_pos) { + yaml_parser_set_scanner_error(parser, "while scanning a tag", + start_mark, "did not find expected whitespace or line break") + return false + } + + end_mark := parser.mark + + // Create a token. + *token = yaml_token_t{ + typ: yaml_TAG_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + value: handle, + suffix: suffix, + } + return true +} + +// Scan a tag handle. +func yaml_parser_scan_tag_handle(parser *yaml_parser_t, directive bool, start_mark yaml_mark_t, handle *[]byte) bool { + // Check the initial '!' character. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if parser.buffer[parser.buffer_pos] != '!' 
{ + yaml_parser_set_scanner_tag_error(parser, directive, + start_mark, "did not find expected '!'") + return false + } + + var s []byte + + // Copy the '!' character. + s = read(parser, s) + + // Copy all subsequent alphabetical and numerical characters. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + for is_alpha(parser.buffer, parser.buffer_pos) { + s = read(parser, s) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Check if the trailing character is '!' and copy it. + if parser.buffer[parser.buffer_pos] == '!' { + s = read(parser, s) + } else { + // It's either the '!' tag or not really a tag handle. If it's a %TAG + // directive, it's an error. If it's a tag token, it must be a part of URI. + if directive && string(s) != "!" { + yaml_parser_set_scanner_tag_error(parser, directive, + start_mark, "did not find expected '!'") + return false + } + } + + *handle = s + return true +} + +// Scan a tag. +func yaml_parser_scan_tag_uri(parser *yaml_parser_t, directive bool, head []byte, start_mark yaml_mark_t, uri *[]byte) bool { + //size_t length = head ? strlen((char *)head) : 0 + var s []byte + hasTag := len(head) > 0 + + // Copy the head if needed. + // + // Note that we don't copy the leading '!' character. + if len(head) > 1 { + s = append(s, head[1:]...) + } + + // Scan the tag. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + // The set of characters that may appear in URI is as follows: + // + // '0'-'9', 'A'-'Z', 'a'-'z', '_', '-', ';', '/', '?', ':', '@', '&', + // '=', '+', '$', ',', '.', '!', '~', '*', '\'', '(', ')', '[', ']', + // '%'. + // [Go] TODO Convert this into more reasonable logic. + for is_alpha(parser.buffer, parser.buffer_pos) || parser.buffer[parser.buffer_pos] == ';' || + parser.buffer[parser.buffer_pos] == '/' || parser.buffer[parser.buffer_pos] == '?' 
|| + parser.buffer[parser.buffer_pos] == ':' || parser.buffer[parser.buffer_pos] == '@' || + parser.buffer[parser.buffer_pos] == '&' || parser.buffer[parser.buffer_pos] == '=' || + parser.buffer[parser.buffer_pos] == '+' || parser.buffer[parser.buffer_pos] == '$' || + parser.buffer[parser.buffer_pos] == ',' || parser.buffer[parser.buffer_pos] == '.' || + parser.buffer[parser.buffer_pos] == '!' || parser.buffer[parser.buffer_pos] == '~' || + parser.buffer[parser.buffer_pos] == '*' || parser.buffer[parser.buffer_pos] == '\'' || + parser.buffer[parser.buffer_pos] == '(' || parser.buffer[parser.buffer_pos] == ')' || + parser.buffer[parser.buffer_pos] == '[' || parser.buffer[parser.buffer_pos] == ']' || + parser.buffer[parser.buffer_pos] == '%' { + // Check if it is a URI-escape sequence. + if parser.buffer[parser.buffer_pos] == '%' { + if !yaml_parser_scan_uri_escapes(parser, directive, start_mark, &s) { + return false + } + } else { + s = read(parser, s) + } + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + hasTag = true + } + + if !hasTag { + yaml_parser_set_scanner_tag_error(parser, directive, + start_mark, "did not find expected tag URI") + return false + } + *uri = s + return true +} + +// Decode an URI-escape sequence corresponding to a single UTF-8 character. +func yaml_parser_scan_uri_escapes(parser *yaml_parser_t, directive bool, start_mark yaml_mark_t, s *[]byte) bool { + + // Decode the required number of characters. + w := 1024 + for w > 0 { + // Check for a URI-escaped octet. + if parser.unread < 3 && !yaml_parser_update_buffer(parser, 3) { + return false + } + + if !(parser.buffer[parser.buffer_pos] == '%' && + is_hex(parser.buffer, parser.buffer_pos+1) && + is_hex(parser.buffer, parser.buffer_pos+2)) { + return yaml_parser_set_scanner_tag_error(parser, directive, + start_mark, "did not find URI escaped octet") + } + + // Get the octet. 
+ octet := byte((as_hex(parser.buffer, parser.buffer_pos+1) << 4) + as_hex(parser.buffer, parser.buffer_pos+2)) + + // If it is the leading octet, determine the length of the UTF-8 sequence. + if w == 1024 { + w = width(octet) + if w == 0 { + return yaml_parser_set_scanner_tag_error(parser, directive, + start_mark, "found an incorrect leading UTF-8 octet") + } + } else { + // Check if the trailing octet is correct. + if octet&0xC0 != 0x80 { + return yaml_parser_set_scanner_tag_error(parser, directive, + start_mark, "found an incorrect trailing UTF-8 octet") + } + } + + // Copy the octet and move the pointers. + *s = append(*s, octet) + skip(parser) + skip(parser) + skip(parser) + w-- + } + return true +} + +// Scan a block scalar. +func yaml_parser_scan_block_scalar(parser *yaml_parser_t, token *yaml_token_t, literal bool) bool { + // Eat the indicator '|' or '>'. + start_mark := parser.mark + skip(parser) + + // Scan the additional block scalar indicators. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + // Check for a chomping indicator. + var chomping, increment int + if parser.buffer[parser.buffer_pos] == '+' || parser.buffer[parser.buffer_pos] == '-' { + // Set the chomping method and eat the indicator. + if parser.buffer[parser.buffer_pos] == '+' { + chomping = +1 + } else { + chomping = -1 + } + skip(parser) + + // Check for an indentation indicator. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if is_digit(parser.buffer, parser.buffer_pos) { + // Check that the indentation is greater than 0. + if parser.buffer[parser.buffer_pos] == '0' { + yaml_parser_set_scanner_error(parser, "while scanning a block scalar", + start_mark, "found an indentation indicator equal to 0") + return false + } + + // Get the indentation level and eat the indicator. 
+ increment = as_digit(parser.buffer, parser.buffer_pos) + skip(parser) + } + + } else if is_digit(parser.buffer, parser.buffer_pos) { + // Do the same as above, but in the opposite order. + + if parser.buffer[parser.buffer_pos] == '0' { + yaml_parser_set_scanner_error(parser, "while scanning a block scalar", + start_mark, "found an indentation indicator equal to 0") + return false + } + increment = as_digit(parser.buffer, parser.buffer_pos) + skip(parser) + + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if parser.buffer[parser.buffer_pos] == '+' || parser.buffer[parser.buffer_pos] == '-' { + if parser.buffer[parser.buffer_pos] == '+' { + chomping = +1 + } else { + chomping = -1 + } + skip(parser) + } + } + + // Eat whitespaces and comments to the end of the line. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + for is_blank(parser.buffer, parser.buffer_pos) { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + if parser.buffer[parser.buffer_pos] == '#' { + if !yaml_parser_scan_line_comment(parser, start_mark) { + return false + } + for !is_breakz(parser.buffer, parser.buffer_pos) { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + } + + // Check if we are at the end of the line. + if !is_breakz(parser.buffer, parser.buffer_pos) { + yaml_parser_set_scanner_error(parser, "while scanning a block scalar", + start_mark, "did not find expected comment or line break") + return false + } + + // Eat a line break. + if is_break(parser.buffer, parser.buffer_pos) { + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + skip_line(parser) + } + + end_mark := parser.mark + + // Set the indentation level if it was specified. 
+ var indent int + if increment > 0 { + if parser.indent >= 0 { + indent = parser.indent + increment + } else { + indent = increment + } + } + + // Scan the leading line breaks and determine the indentation level if needed. + var s, leading_break, trailing_breaks []byte + if !yaml_parser_scan_block_scalar_breaks(parser, &indent, &trailing_breaks, start_mark, &end_mark) { + return false + } + + // Scan the block scalar content. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + var leading_blank, trailing_blank bool + for parser.mark.column == indent && !is_z(parser.buffer, parser.buffer_pos) { + // We are at the beginning of a non-empty line. + + // Is it a trailing whitespace? + trailing_blank = is_blank(parser.buffer, parser.buffer_pos) + + // Check if we need to fold the leading line break. + if !literal && !leading_blank && !trailing_blank && len(leading_break) > 0 && leading_break[0] == '\n' { + // Do we need to join the lines by space? + if len(trailing_breaks) == 0 { + s = append(s, ' ') + } + } else { + s = append(s, leading_break...) + } + leading_break = leading_break[:0] + + // Append the remaining line breaks. + s = append(s, trailing_breaks...) + trailing_breaks = trailing_breaks[:0] + + // Is it a leading whitespace? + leading_blank = is_blank(parser.buffer, parser.buffer_pos) + + // Consume the current line. + for !is_breakz(parser.buffer, parser.buffer_pos) { + s = read(parser, s) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Consume the line break. + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + + leading_break = read_line(parser, leading_break) + + // Eat the following indentation spaces and line breaks. + if !yaml_parser_scan_block_scalar_breaks(parser, &indent, &trailing_breaks, start_mark, &end_mark) { + return false + } + } + + // Chomp the tail. + if chomping != -1 { + s = append(s, leading_break...) 
+ } + if chomping == 1 { + s = append(s, trailing_breaks...) + } + + // Create a token. + *token = yaml_token_t{ + typ: yaml_SCALAR_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + value: s, + style: yaml_LITERAL_SCALAR_STYLE, + } + if !literal { + token.style = yaml_FOLDED_SCALAR_STYLE + } + return true +} + +// Scan indentation spaces and line breaks for a block scalar. Determine the +// indentation level if needed. +func yaml_parser_scan_block_scalar_breaks(parser *yaml_parser_t, indent *int, breaks *[]byte, start_mark yaml_mark_t, end_mark *yaml_mark_t) bool { + *end_mark = parser.mark + + // Eat the indentation spaces and line breaks. + max_indent := 0 + for { + // Eat the indentation spaces. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + for (*indent == 0 || parser.mark.column < *indent) && is_space(parser.buffer, parser.buffer_pos) { + skip(parser) + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + if parser.mark.column > max_indent { + max_indent = parser.mark.column + } + + // Check for a tab character messing the indentation. + if (*indent == 0 || parser.mark.column < *indent) && is_tab(parser.buffer, parser.buffer_pos) { + return yaml_parser_set_scanner_error(parser, "while scanning a block scalar", + start_mark, "found a tab character where an indentation space is expected") + } + + // Have we found a non-empty line? + if !is_break(parser.buffer, parser.buffer_pos) { + break + } + + // Consume the line break. + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + // [Go] Should really be returning breaks instead. + *breaks = read_line(parser, *breaks) + *end_mark = parser.mark + } + + // Determine the indentation level if needed. + if *indent == 0 { + *indent = max_indent + if *indent < parser.indent+1 { + *indent = parser.indent + 1 + } + if *indent < 1 { + *indent = 1 + } + } + return true +} + +// Scan a quoted scalar. 
+func yaml_parser_scan_flow_scalar(parser *yaml_parser_t, token *yaml_token_t, single bool) bool { + // Eat the left quote. + start_mark := parser.mark + skip(parser) + + // Consume the content of the quoted scalar. + var s, leading_break, trailing_breaks, whitespaces []byte + for { + // Check that there are no document indicators at the beginning of the line. + if parser.unread < 4 && !yaml_parser_update_buffer(parser, 4) { + return false + } + + if parser.mark.column == 0 && + ((parser.buffer[parser.buffer_pos+0] == '-' && + parser.buffer[parser.buffer_pos+1] == '-' && + parser.buffer[parser.buffer_pos+2] == '-') || + (parser.buffer[parser.buffer_pos+0] == '.' && + parser.buffer[parser.buffer_pos+1] == '.' && + parser.buffer[parser.buffer_pos+2] == '.')) && + is_blankz(parser.buffer, parser.buffer_pos+3) { + yaml_parser_set_scanner_error(parser, "while scanning a quoted scalar", + start_mark, "found unexpected document indicator") + return false + } + + // Check for EOF. + if is_z(parser.buffer, parser.buffer_pos) { + yaml_parser_set_scanner_error(parser, "while scanning a quoted scalar", + start_mark, "found unexpected end of stream") + return false + } + + // Consume non-blank characters. + leading_blanks := false + for !is_blankz(parser.buffer, parser.buffer_pos) { + if single && parser.buffer[parser.buffer_pos] == '\'' && parser.buffer[parser.buffer_pos+1] == '\'' { + // Is is an escaped single quote. + s = append(s, '\'') + skip(parser) + skip(parser) + + } else if single && parser.buffer[parser.buffer_pos] == '\'' { + // It is a right single quote. + break + } else if !single && parser.buffer[parser.buffer_pos] == '"' { + // It is a right double quote. + break + + } else if !single && parser.buffer[parser.buffer_pos] == '\\' && is_break(parser.buffer, parser.buffer_pos+1) { + // It is an escaped line break. 
+ if parser.unread < 3 && !yaml_parser_update_buffer(parser, 3) { + return false + } + skip(parser) + skip_line(parser) + leading_blanks = true + break + + } else if !single && parser.buffer[parser.buffer_pos] == '\\' { + // It is an escape sequence. + code_length := 0 + + // Check the escape character. + switch parser.buffer[parser.buffer_pos+1] { + case '0': + s = append(s, 0) + case 'a': + s = append(s, '\x07') + case 'b': + s = append(s, '\x08') + case 't', '\t': + s = append(s, '\x09') + case 'n': + s = append(s, '\x0A') + case 'v': + s = append(s, '\x0B') + case 'f': + s = append(s, '\x0C') + case 'r': + s = append(s, '\x0D') + case 'e': + s = append(s, '\x1B') + case ' ': + s = append(s, '\x20') + case '"': + s = append(s, '"') + case '\'': + s = append(s, '\'') + case '\\': + s = append(s, '\\') + case 'N': // NEL (#x85) + s = append(s, '\xC2') + s = append(s, '\x85') + case '_': // #xA0 + s = append(s, '\xC2') + s = append(s, '\xA0') + case 'L': // LS (#x2028) + s = append(s, '\xE2') + s = append(s, '\x80') + s = append(s, '\xA8') + case 'P': // PS (#x2029) + s = append(s, '\xE2') + s = append(s, '\x80') + s = append(s, '\xA9') + case 'x': + code_length = 2 + case 'u': + code_length = 4 + case 'U': + code_length = 8 + default: + yaml_parser_set_scanner_error(parser, "while parsing a quoted scalar", + start_mark, "found unknown escape character") + return false + } + + skip(parser) + skip(parser) + + // Consume an arbitrary escape code. + if code_length > 0 { + var value int + + // Scan the character value. 
+ if parser.unread < code_length && !yaml_parser_update_buffer(parser, code_length) { + return false + } + for k := 0; k < code_length; k++ { + if !is_hex(parser.buffer, parser.buffer_pos+k) { + yaml_parser_set_scanner_error(parser, "while parsing a quoted scalar", + start_mark, "did not find expected hexdecimal number") + return false + } + value = (value << 4) + as_hex(parser.buffer, parser.buffer_pos+k) + } + + // Check the value and write the character. + if (value >= 0xD800 && value <= 0xDFFF) || value > 0x10FFFF { + yaml_parser_set_scanner_error(parser, "while parsing a quoted scalar", + start_mark, "found invalid Unicode character escape code") + return false + } + if value <= 0x7F { + s = append(s, byte(value)) + } else if value <= 0x7FF { + s = append(s, byte(0xC0+(value>>6))) + s = append(s, byte(0x80+(value&0x3F))) + } else if value <= 0xFFFF { + s = append(s, byte(0xE0+(value>>12))) + s = append(s, byte(0x80+((value>>6)&0x3F))) + s = append(s, byte(0x80+(value&0x3F))) + } else { + s = append(s, byte(0xF0+(value>>18))) + s = append(s, byte(0x80+((value>>12)&0x3F))) + s = append(s, byte(0x80+((value>>6)&0x3F))) + s = append(s, byte(0x80+(value&0x3F))) + } + + // Advance the pointer. + for k := 0; k < code_length; k++ { + skip(parser) + } + } + } else { + // It is a non-escaped non-blank character. + s = read(parser, s) + } + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + } + + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + // Check if we are at the end of the scalar. + if single { + if parser.buffer[parser.buffer_pos] == '\'' { + break + } + } else { + if parser.buffer[parser.buffer_pos] == '"' { + break + } + } + + // Consume blank characters. + for is_blank(parser.buffer, parser.buffer_pos) || is_break(parser.buffer, parser.buffer_pos) { + if is_blank(parser.buffer, parser.buffer_pos) { + // Consume a space or a tab character. 
+ if !leading_blanks { + whitespaces = read(parser, whitespaces) + } else { + skip(parser) + } + } else { + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + + // Check if it is a first line break. + if !leading_blanks { + whitespaces = whitespaces[:0] + leading_break = read_line(parser, leading_break) + leading_blanks = true + } else { + trailing_breaks = read_line(parser, trailing_breaks) + } + } + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Join the whitespaces or fold line breaks. + if leading_blanks { + // Do we need to fold line breaks? + if len(leading_break) > 0 && leading_break[0] == '\n' { + if len(trailing_breaks) == 0 { + s = append(s, ' ') + } else { + s = append(s, trailing_breaks...) + } + } else { + s = append(s, leading_break...) + s = append(s, trailing_breaks...) + } + trailing_breaks = trailing_breaks[:0] + leading_break = leading_break[:0] + } else { + s = append(s, whitespaces...) + whitespaces = whitespaces[:0] + } + } + + // Eat the right quote. + skip(parser) + end_mark := parser.mark + + // Create a token. + *token = yaml_token_t{ + typ: yaml_SCALAR_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + value: s, + style: yaml_SINGLE_QUOTED_SCALAR_STYLE, + } + if !single { + token.style = yaml_DOUBLE_QUOTED_SCALAR_STYLE + } + return true +} + +// Scan a plain scalar. +func yaml_parser_scan_plain_scalar(parser *yaml_parser_t, token *yaml_token_t) bool { + + var s, leading_break, trailing_breaks, whitespaces []byte + var leading_blanks bool + var indent = parser.indent + 1 + + start_mark := parser.mark + end_mark := parser.mark + + // Consume the content of the plain scalar. + for { + // Check for a document indicator. 
+ if parser.unread < 4 && !yaml_parser_update_buffer(parser, 4) { + return false + } + if parser.mark.column == 0 && + ((parser.buffer[parser.buffer_pos+0] == '-' && + parser.buffer[parser.buffer_pos+1] == '-' && + parser.buffer[parser.buffer_pos+2] == '-') || + (parser.buffer[parser.buffer_pos+0] == '.' && + parser.buffer[parser.buffer_pos+1] == '.' && + parser.buffer[parser.buffer_pos+2] == '.')) && + is_blankz(parser.buffer, parser.buffer_pos+3) { + break + } + + // Check for a comment. + if parser.buffer[parser.buffer_pos] == '#' { + break + } + + // Consume non-blank characters. + for !is_blankz(parser.buffer, parser.buffer_pos) { + + // Check for indicators that may end a plain scalar. + if (parser.buffer[parser.buffer_pos] == ':' && is_blankz(parser.buffer, parser.buffer_pos+1)) || + (parser.flow_level > 0 && + (parser.buffer[parser.buffer_pos] == ',' || + parser.buffer[parser.buffer_pos] == '?' || parser.buffer[parser.buffer_pos] == '[' || + parser.buffer[parser.buffer_pos] == ']' || parser.buffer[parser.buffer_pos] == '{' || + parser.buffer[parser.buffer_pos] == '}')) { + break + } + + // Check if we need to join whitespaces and breaks. + if leading_blanks || len(whitespaces) > 0 { + if leading_blanks { + // Do we need to fold line breaks? + if leading_break[0] == '\n' { + if len(trailing_breaks) == 0 { + s = append(s, ' ') + } else { + s = append(s, trailing_breaks...) + } + } else { + s = append(s, leading_break...) + s = append(s, trailing_breaks...) + } + trailing_breaks = trailing_breaks[:0] + leading_break = leading_break[:0] + leading_blanks = false + } else { + s = append(s, whitespaces...) + whitespaces = whitespaces[:0] + } + } + + // Copy the character. + s = read(parser, s) + + end_mark = parser.mark + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + } + + // Is it the end? 
+ if !(is_blank(parser.buffer, parser.buffer_pos) || is_break(parser.buffer, parser.buffer_pos)) { + break + } + + // Consume blank characters. + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + + for is_blank(parser.buffer, parser.buffer_pos) || is_break(parser.buffer, parser.buffer_pos) { + if is_blank(parser.buffer, parser.buffer_pos) { + + // Check for tab characters that abuse indentation. + if leading_blanks && parser.mark.column < indent && is_tab(parser.buffer, parser.buffer_pos) { + yaml_parser_set_scanner_error(parser, "while scanning a plain scalar", + start_mark, "found a tab character that violates indentation") + return false + } + + // Consume a space or a tab character. + if !leading_blanks { + whitespaces = read(parser, whitespaces) + } else { + skip(parser) + } + } else { + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + + // Check if it is a first line break. + if !leading_blanks { + whitespaces = whitespaces[:0] + leading_break = read_line(parser, leading_break) + leading_blanks = true + } else { + trailing_breaks = read_line(parser, trailing_breaks) + } + } + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + } + + // Check indentation level. + if parser.flow_level == 0 && parser.mark.column < indent { + break + } + } + + // Create a token. + *token = yaml_token_t{ + typ: yaml_SCALAR_TOKEN, + start_mark: start_mark, + end_mark: end_mark, + value: s, + style: yaml_PLAIN_SCALAR_STYLE, + } + + // Note that we change the 'simple_key_allowed' flag. 
+ if leading_blanks { + parser.simple_key_allowed = true + } + return true +} + +func yaml_parser_scan_line_comment(parser *yaml_parser_t, token_mark yaml_mark_t) bool { + if parser.newlines > 0 { + return true + } + + var start_mark yaml_mark_t + var text []byte + + for peek := 0; peek < 512; peek++ { + if parser.unread < peek+1 && !yaml_parser_update_buffer(parser, peek+1) { + break + } + if is_blank(parser.buffer, parser.buffer_pos+peek) { + continue + } + if parser.buffer[parser.buffer_pos+peek] == '#' { + seen := parser.mark.index+peek + for { + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if is_breakz(parser.buffer, parser.buffer_pos) { + if parser.mark.index >= seen { + break + } + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + skip_line(parser) + } else if parser.mark.index >= seen { + if len(text) == 0 { + start_mark = parser.mark + } + text = read(parser, text) + } else { + skip(parser) + } + } + } + break + } + if len(text) > 0 { + parser.comments = append(parser.comments, yaml_comment_t{ + token_mark: token_mark, + start_mark: start_mark, + line: text, + }) + } + return true +} + +func yaml_parser_scan_comments(parser *yaml_parser_t, scan_mark yaml_mark_t) bool { + token := parser.tokens[len(parser.tokens)-1] + + if token.typ == yaml_FLOW_ENTRY_TOKEN && len(parser.tokens) > 1 { + token = parser.tokens[len(parser.tokens)-2] + } + + var token_mark = token.start_mark + var start_mark yaml_mark_t + var next_indent = parser.indent + if next_indent < 0 { + next_indent = 0 + } + + var recent_empty = false + var first_empty = parser.newlines <= 1 + + var line = parser.mark.line + var column = parser.mark.column + + var text []byte + + // The foot line is the place where a comment must start to + // still be considered as a foot of the prior content. + // If there's some content in the currently parsed line, then + // the foot is the line below it. 
+ var foot_line = -1 + if scan_mark.line > 0 { + foot_line = parser.mark.line-parser.newlines+1 + if parser.newlines == 0 && parser.mark.column > 1 { + foot_line++ + } + } + + var peek = 0 + for ; peek < 512; peek++ { + if parser.unread < peek+1 && !yaml_parser_update_buffer(parser, peek+1) { + break + } + column++ + if is_blank(parser.buffer, parser.buffer_pos+peek) { + continue + } + c := parser.buffer[parser.buffer_pos+peek] + var close_flow = parser.flow_level > 0 && (c == ']' || c == '}') + if close_flow || is_breakz(parser.buffer, parser.buffer_pos+peek) { + // Got line break or terminator. + if close_flow || !recent_empty { + if close_flow || first_empty && (start_mark.line == foot_line && token.typ != yaml_VALUE_TOKEN || start_mark.column-1 < next_indent) { + // This is the first empty line and there were no empty lines before, + // so this initial part of the comment is a foot of the prior token + // instead of being a head for the following one. Split it up. + // Alternatively, this might also be the last comment inside a flow + // scope, so it must be a footer. + if len(text) > 0 { + if start_mark.column-1 < next_indent { + // If dedented it's unrelated to the prior token. 
+ token_mark = start_mark + } + parser.comments = append(parser.comments, yaml_comment_t{ + scan_mark: scan_mark, + token_mark: token_mark, + start_mark: start_mark, + end_mark: yaml_mark_t{parser.mark.index + peek, line, column}, + foot: text, + }) + scan_mark = yaml_mark_t{parser.mark.index + peek, line, column} + token_mark = scan_mark + text = nil + } + } else { + if len(text) > 0 && parser.buffer[parser.buffer_pos+peek] != 0 { + text = append(text, '\n') + } + } + } + if !is_break(parser.buffer, parser.buffer_pos+peek) { + break + } + first_empty = false + recent_empty = true + column = 0 + line++ + continue + } + + if len(text) > 0 && (close_flow || column-1 < next_indent && column != start_mark.column) { + // The comment at the different indentation is a foot of the + // preceding data rather than a head of the upcoming one. + parser.comments = append(parser.comments, yaml_comment_t{ + scan_mark: scan_mark, + token_mark: token_mark, + start_mark: start_mark, + end_mark: yaml_mark_t{parser.mark.index + peek, line, column}, + foot: text, + }) + scan_mark = yaml_mark_t{parser.mark.index + peek, line, column} + token_mark = scan_mark + text = nil + } + + if parser.buffer[parser.buffer_pos+peek] != '#' { + break + } + + if len(text) == 0 { + start_mark = yaml_mark_t{parser.mark.index + peek, line, column} + } else { + text = append(text, '\n') + } + + recent_empty = false + + // Consume until after the consumed comment line. 
+ seen := parser.mark.index+peek + for { + if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { + return false + } + if is_breakz(parser.buffer, parser.buffer_pos) { + if parser.mark.index >= seen { + break + } + if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { + return false + } + skip_line(parser) + } else if parser.mark.index >= seen { + text = read(parser, text) + } else { + skip(parser) + } + } + + peek = 0 + column = 0 + line = parser.mark.line + next_indent = parser.indent + if next_indent < 0 { + next_indent = 0 + } + } + + if len(text) > 0 { + parser.comments = append(parser.comments, yaml_comment_t{ + scan_mark: scan_mark, + token_mark: start_mark, + start_mark: start_mark, + end_mark: yaml_mark_t{parser.mark.index + peek - 1, line, column}, + head: text, + }) + } + return true +} diff --git a/vendor/gopkg.in/yaml.v3/sorter.go b/vendor/gopkg.in/yaml.v3/sorter.go new file mode 100644 index 000000000..9210ece7e --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/sorter.go @@ -0,0 +1,134 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package yaml + +import ( + "reflect" + "unicode" +) + +type keyList []reflect.Value + +func (l keyList) Len() int { return len(l) } +func (l keyList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } +func (l keyList) Less(i, j int) bool { + a := l[i] + b := l[j] + ak := a.Kind() + bk := b.Kind() + for (ak == reflect.Interface || ak == reflect.Ptr) && !a.IsNil() { + a = a.Elem() + ak = a.Kind() + } + for (bk == reflect.Interface || bk == reflect.Ptr) && !b.IsNil() { + b = b.Elem() + bk = b.Kind() + } + af, aok := keyFloat(a) + bf, bok := keyFloat(b) + if aok && bok { + if af != bf { + return af < bf + } + if ak != bk { + return ak < bk + } + return numLess(a, b) + } + if ak != reflect.String || bk != reflect.String { + return ak < bk + } + ar, br := []rune(a.String()), []rune(b.String()) + digits := false + for i := 0; i < len(ar) && i < len(br); i++ { + if ar[i] == br[i] { + digits = unicode.IsDigit(ar[i]) + continue + } + al := unicode.IsLetter(ar[i]) + bl := unicode.IsLetter(br[i]) + if al && bl { + return ar[i] < br[i] + } + if al || bl { + if digits { + return al + } else { + return bl + } + } + var ai, bi int + var an, bn int64 + if ar[i] == '0' || br[i] == '0' { + for j := i - 1; j >= 0 && unicode.IsDigit(ar[j]); j-- { + if ar[j] != '0' { + an = 1 + bn = 1 + break + } + } + } + for ai = i; ai < len(ar) && unicode.IsDigit(ar[ai]); ai++ { + an = an*10 + int64(ar[ai]-'0') + } + for bi = i; bi < len(br) && unicode.IsDigit(br[bi]); bi++ { + bn = bn*10 + int64(br[bi]-'0') + } + if an != bn { + return an < bn + } + if ai != bi { + return ai < bi + } + return ar[i] < br[i] + } + return len(ar) < len(br) +} + +// keyFloat returns a float value for v if it is a number/bool +// and whether it is a number/bool or not. 
+func keyFloat(v reflect.Value) (f float64, ok bool) { + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(v.Int()), true + case reflect.Float32, reflect.Float64: + return v.Float(), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return float64(v.Uint()), true + case reflect.Bool: + if v.Bool() { + return 1, true + } + return 0, true + } + return 0, false +} + +// numLess returns whether a < b. +// a and b must necessarily have the same kind. +func numLess(a, b reflect.Value) bool { + switch a.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return a.Int() < b.Int() + case reflect.Float32, reflect.Float64: + return a.Float() < b.Float() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return a.Uint() < b.Uint() + case reflect.Bool: + return !a.Bool() && b.Bool() + } + panic("not a number") +} diff --git a/vendor/gopkg.in/yaml.v3/writerc.go b/vendor/gopkg.in/yaml.v3/writerc.go new file mode 100644 index 000000000..b8a116bf9 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/writerc.go @@ -0,0 +1,48 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// Copyright (c) 2006-2010 Kirill Simonov +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package yaml + +// Set the writer error and return false. +func yaml_emitter_set_writer_error(emitter *yaml_emitter_t, problem string) bool { + emitter.error = yaml_WRITER_ERROR + emitter.problem = problem + return false +} + +// Flush the output buffer. +func yaml_emitter_flush(emitter *yaml_emitter_t) bool { + if emitter.write_handler == nil { + panic("write handler not set") + } + + // Check if the buffer is empty. + if emitter.buffer_pos == 0 { + return true + } + + if err := emitter.write_handler(emitter, emitter.buffer[:emitter.buffer_pos]); err != nil { + return yaml_emitter_set_writer_error(emitter, "write error: "+err.Error()) + } + emitter.buffer_pos = 0 + return true +} diff --git a/vendor/gopkg.in/yaml.v3/yaml.go b/vendor/gopkg.in/yaml.v3/yaml.go new file mode 100644 index 000000000..8cec6da48 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/yaml.go @@ -0,0 +1,698 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package yaml implements YAML support for the Go language. +// +// Source code and other details for the project are available at GitHub: +// +// https://github.com/go-yaml/yaml +// +package yaml + +import ( + "errors" + "fmt" + "io" + "reflect" + "strings" + "sync" + "unicode/utf8" +) + +// The Unmarshaler interface may be implemented by types to customize their +// behavior when being unmarshaled from a YAML document. +type Unmarshaler interface { + UnmarshalYAML(value *Node) error +} + +type obsoleteUnmarshaler interface { + UnmarshalYAML(unmarshal func(interface{}) error) error +} + +// The Marshaler interface may be implemented by types to customize their +// behavior when being marshaled into a YAML document. The returned value +// is marshaled in place of the original value implementing Marshaler. +// +// If an error is returned by MarshalYAML, the marshaling procedure stops +// and returns with the provided error. +type Marshaler interface { + MarshalYAML() (interface{}, error) +} + +// Unmarshal decodes the first document found within the in byte slice +// and assigns decoded values into the out value. +// +// Maps and pointers (to a struct, string, int, etc) are accepted as out +// values. If an internal pointer within a struct is not initialized, +// the yaml package will initialize it if necessary for unmarshalling +// the provided data. The out parameter must not be nil. +// +// The type of the decoded values should be compatible with the respective +// values in out. If one or more values cannot be decoded due to a type +// mismatches, decoding continues partially until the end of the YAML +// content, and a *yaml.TypeError is returned with details for all +// missed values. +// +// Struct fields are only unmarshalled if they are exported (have an +// upper case first letter), and are unmarshalled using the field name +// lowercased as the default key. 
Custom keys may be defined via the +// "yaml" name in the field tag: the content preceding the first comma +// is used as the key, and the following comma-separated options are +// used to tweak the marshalling process (see Marshal). +// Conflicting names result in a runtime error. +// +// For example: +// +// type T struct { +// F int `yaml:"a,omitempty"` +// B int +// } +// var t T +// yaml.Unmarshal([]byte("a: 1\nb: 2"), &t) +// +// See the documentation of Marshal for the format of tags and a list of +// supported tag options. +// +func Unmarshal(in []byte, out interface{}) (err error) { + return unmarshal(in, out, false) +} + +// A Decoder reads and decodes YAML values from an input stream. +type Decoder struct { + parser *parser + knownFields bool +} + +// NewDecoder returns a new decoder that reads from r. +// +// The decoder introduces its own buffering and may read +// data from r beyond the YAML values requested. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{ + parser: newParserFromReader(r), + } +} + +// KnownFields ensures that the keys in decoded mappings to +// exist as fields in the struct being decoded into. +func (dec *Decoder) KnownFields(enable bool) { + dec.knownFields = enable +} + +// Decode reads the next YAML-encoded value from its input +// and stores it in the value pointed to by v. +// +// See the documentation for Unmarshal for details about the +// conversion of YAML into a Go value. +func (dec *Decoder) Decode(v interface{}) (err error) { + d := newDecoder() + d.knownFields = dec.knownFields + defer handleErr(&err) + node := dec.parser.parse() + if node == nil { + return io.EOF + } + out := reflect.ValueOf(v) + if out.Kind() == reflect.Ptr && !out.IsNil() { + out = out.Elem() + } + d.unmarshal(node, out) + if len(d.terrors) > 0 { + return &TypeError{d.terrors} + } + return nil +} + +// Decode decodes the node and stores its data into the value pointed to by v. 
+// +// See the documentation for Unmarshal for details about the +// conversion of YAML into a Go value. +func (n *Node) Decode(v interface{}) (err error) { + d := newDecoder() + defer handleErr(&err) + out := reflect.ValueOf(v) + if out.Kind() == reflect.Ptr && !out.IsNil() { + out = out.Elem() + } + d.unmarshal(n, out) + if len(d.terrors) > 0 { + return &TypeError{d.terrors} + } + return nil +} + +func unmarshal(in []byte, out interface{}, strict bool) (err error) { + defer handleErr(&err) + d := newDecoder() + p := newParser(in) + defer p.destroy() + node := p.parse() + if node != nil { + v := reflect.ValueOf(out) + if v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + d.unmarshal(node, v) + } + if len(d.terrors) > 0 { + return &TypeError{d.terrors} + } + return nil +} + +// Marshal serializes the value provided into a YAML document. The structure +// of the generated document will reflect the structure of the value itself. +// Maps and pointers (to struct, string, int, etc) are accepted as the in value. +// +// Struct fields are only marshalled if they are exported (have an upper case +// first letter), and are marshalled using the field name lowercased as the +// default key. Custom keys may be defined via the "yaml" name in the field +// tag: the content preceding the first comma is used as the key, and the +// following comma-separated options are used to tweak the marshalling process. +// Conflicting names result in a runtime error. +// +// The field tag format accepted is: +// +// `(...) yaml:"[][,[,]]" (...)` +// +// The following flags are currently supported: +// +// omitempty Only include the field if it's not set to the zero +// value for the type or to empty slices or maps. +// Zero valued structs will be omitted if all their public +// fields are zero, unless they implement an IsZero +// method (see the IsZeroer interface type), in which +// case the field will be excluded if IsZero returns true. 
+// +// flow Marshal using a flow style (useful for structs, +// sequences and maps). +// +// inline Inline the field, which must be a struct or a map, +// causing all of its fields or keys to be processed as if +// they were part of the outer struct. For maps, keys must +// not conflict with the yaml keys of other struct fields. +// +// In addition, if the key is "-", the field is ignored. +// +// For example: +// +// type T struct { +// F int `yaml:"a,omitempty"` +// B int +// } +// yaml.Marshal(&T{B: 2}) // Returns "b: 2\n" +// yaml.Marshal(&T{F: 1}} // Returns "a: 1\nb: 0\n" +// +func Marshal(in interface{}) (out []byte, err error) { + defer handleErr(&err) + e := newEncoder() + defer e.destroy() + e.marshalDoc("", reflect.ValueOf(in)) + e.finish() + out = e.out + return +} + +// An Encoder writes YAML values to an output stream. +type Encoder struct { + encoder *encoder +} + +// NewEncoder returns a new encoder that writes to w. +// The Encoder should be closed after use to flush all data +// to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{ + encoder: newEncoderWithWriter(w), + } +} + +// Encode writes the YAML encoding of v to the stream. +// If multiple items are encoded to the stream, the +// second and subsequent document will be preceded +// with a "---" document separator, but the first will not. +// +// See the documentation for Marshal for details about the conversion of Go +// values to YAML. +func (e *Encoder) Encode(v interface{}) (err error) { + defer handleErr(&err) + e.encoder.marshalDoc("", reflect.ValueOf(v)) + return nil +} + +// Encode encodes value v and stores its representation in n. +// +// See the documentation for Marshal for details about the +// conversion of Go values into YAML. 
+func (n *Node) Encode(v interface{}) (err error) { + defer handleErr(&err) + e := newEncoder() + defer e.destroy() + e.marshalDoc("", reflect.ValueOf(v)) + e.finish() + p := newParser(e.out) + p.textless = true + defer p.destroy() + doc := p.parse() + *n = *doc.Content[0] + return nil +} + +// SetIndent changes the used indentation used when encoding. +func (e *Encoder) SetIndent(spaces int) { + if spaces < 0 { + panic("yaml: cannot indent to a negative number of spaces") + } + e.encoder.indent = spaces +} + +// Close closes the encoder by writing any remaining data. +// It does not write a stream terminating string "...". +func (e *Encoder) Close() (err error) { + defer handleErr(&err) + e.encoder.finish() + return nil +} + +func handleErr(err *error) { + if v := recover(); v != nil { + if e, ok := v.(yamlError); ok { + *err = e.err + } else { + panic(v) + } + } +} + +type yamlError struct { + err error +} + +func fail(err error) { + panic(yamlError{err}) +} + +func failf(format string, args ...interface{}) { + panic(yamlError{fmt.Errorf("yaml: "+format, args...)}) +} + +// A TypeError is returned by Unmarshal when one or more fields in +// the YAML document cannot be properly decoded into the requested +// types. When this error is returned, the value is still +// unmarshaled partially. +type TypeError struct { + Errors []string +} + +func (e *TypeError) Error() string { + return fmt.Sprintf("yaml: unmarshal errors:\n %s", strings.Join(e.Errors, "\n ")) +} + +type Kind uint32 + +const ( + DocumentNode Kind = 1 << iota + SequenceNode + MappingNode + ScalarNode + AliasNode +) + +type Style uint32 + +const ( + TaggedStyle Style = 1 << iota + DoubleQuotedStyle + SingleQuotedStyle + LiteralStyle + FoldedStyle + FlowStyle +) + +// Node represents an element in the YAML document hierarchy. 
While documents +// are typically encoded and decoded into higher level types, such as structs +// and maps, Node is an intermediate representation that allows detailed +// control over the content being decoded or encoded. +// +// It's worth noting that although Node offers access into details such as +// line numbers, colums, and comments, the content when re-encoded will not +// have its original textual representation preserved. An effort is made to +// render the data plesantly, and to preserve comments near the data they +// describe, though. +// +// Values that make use of the Node type interact with the yaml package in the +// same way any other type would do, by encoding and decoding yaml data +// directly or indirectly into them. +// +// For example: +// +// var person struct { +// Name string +// Address yaml.Node +// } +// err := yaml.Unmarshal(data, &person) +// +// Or by itself: +// +// var person Node +// err := yaml.Unmarshal(data, &person) +// +type Node struct { + // Kind defines whether the node is a document, a mapping, a sequence, + // a scalar value, or an alias to another node. The specific data type of + // scalar nodes may be obtained via the ShortTag and LongTag methods. + Kind Kind + + // Style allows customizing the apperance of the node in the tree. + Style Style + + // Tag holds the YAML tag defining the data type for the value. + // When decoding, this field will always be set to the resolved tag, + // even when it wasn't explicitly provided in the YAML content. + // When encoding, if this field is unset the value type will be + // implied from the node properties, and if it is set, it will only + // be serialized into the representation if TaggedStyle is used or + // the implicit tag diverges from the provided one. + Tag string + + // Value holds the unescaped and unquoted represenation of the value. + Value string + + // Anchor holds the anchor name for this node, which allows aliases to point to it. 
+ Anchor string + + // Alias holds the node that this alias points to. Only valid when Kind is AliasNode. + Alias *Node + + // Content holds contained nodes for documents, mappings, and sequences. + Content []*Node + + // HeadComment holds any comments in the lines preceding the node and + // not separated by an empty line. + HeadComment string + + // LineComment holds any comments at the end of the line where the node is in. + LineComment string + + // FootComment holds any comments following the node and before empty lines. + FootComment string + + // Line and Column hold the node position in the decoded YAML text. + // These fields are not respected when encoding the node. + Line int + Column int +} + +// IsZero returns whether the node has all of its fields unset. +func (n *Node) IsZero() bool { + return n.Kind == 0 && n.Style == 0 && n.Tag == "" && n.Value == "" && n.Anchor == "" && n.Alias == nil && n.Content == nil && + n.HeadComment == "" && n.LineComment == "" && n.FootComment == "" && n.Line == 0 && n.Column == 0 +} + + +// LongTag returns the long form of the tag that indicates the data type for +// the node. If the Tag field isn't explicitly defined, one will be computed +// based on the node properties. +func (n *Node) LongTag() string { + return longTag(n.ShortTag()) +} + +// ShortTag returns the short form of the YAML tag that indicates data type for +// the node. If the Tag field isn't explicitly defined, one will be computed +// based on the node properties. +func (n *Node) ShortTag() string { + if n.indicatedString() { + return strTag + } + if n.Tag == "" || n.Tag == "!" { + switch n.Kind { + case MappingNode: + return mapTag + case SequenceNode: + return seqTag + case AliasNode: + if n.Alias != nil { + return n.Alias.ShortTag() + } + case ScalarNode: + tag, _ := resolve("", n.Value) + return tag + case 0: + // Special case to make the zero value convenient. 
+ if n.IsZero() { + return nullTag + } + } + return "" + } + return shortTag(n.Tag) +} + +func (n *Node) indicatedString() bool { + return n.Kind == ScalarNode && + (shortTag(n.Tag) == strTag || + (n.Tag == "" || n.Tag == "!") && n.Style&(SingleQuotedStyle|DoubleQuotedStyle|LiteralStyle|FoldedStyle) != 0) +} + +// SetString is a convenience function that sets the node to a string value +// and defines its style in a pleasant way depending on its content. +func (n *Node) SetString(s string) { + n.Kind = ScalarNode + if utf8.ValidString(s) { + n.Value = s + n.Tag = strTag + } else { + n.Value = encodeBase64(s) + n.Tag = binaryTag + } + if strings.Contains(n.Value, "\n") { + n.Style = LiteralStyle + } +} + +// -------------------------------------------------------------------------- +// Maintain a mapping of keys to structure field indexes + +// The code in this section was copied from mgo/bson. + +// structInfo holds details for the serialization of fields of +// a given struct. +type structInfo struct { + FieldsMap map[string]fieldInfo + FieldsList []fieldInfo + + // InlineMap is the number of the field in the struct that + // contains an ,inline map, or -1 if there's none. + InlineMap int + + // InlineUnmarshalers holds indexes to inlined fields that + // contain unmarshaler values. + InlineUnmarshalers [][]int +} + +type fieldInfo struct { + Key string + Num int + OmitEmpty bool + Flow bool + // Id holds the unique field identifier, so we can cheaply + // check for field duplicates without maintaining an extra map. + Id int + + // Inline holds the field index if the field is part of an inlined struct. 
+ Inline []int +} + +var structMap = make(map[reflect.Type]*structInfo) +var fieldMapMutex sync.RWMutex +var unmarshalerType reflect.Type + +func init() { + var v Unmarshaler + unmarshalerType = reflect.ValueOf(&v).Elem().Type() +} + +func getStructInfo(st reflect.Type) (*structInfo, error) { + fieldMapMutex.RLock() + sinfo, found := structMap[st] + fieldMapMutex.RUnlock() + if found { + return sinfo, nil + } + + n := st.NumField() + fieldsMap := make(map[string]fieldInfo) + fieldsList := make([]fieldInfo, 0, n) + inlineMap := -1 + inlineUnmarshalers := [][]int(nil) + for i := 0; i != n; i++ { + field := st.Field(i) + if field.PkgPath != "" && !field.Anonymous { + continue // Private field + } + + info := fieldInfo{Num: i} + + tag := field.Tag.Get("yaml") + if tag == "" && strings.Index(string(field.Tag), ":") < 0 { + tag = string(field.Tag) + } + if tag == "-" { + continue + } + + inline := false + fields := strings.Split(tag, ",") + if len(fields) > 1 { + for _, flag := range fields[1:] { + switch flag { + case "omitempty": + info.OmitEmpty = true + case "flow": + info.Flow = true + case "inline": + inline = true + default: + return nil, errors.New(fmt.Sprintf("unsupported flag %q in tag %q of type %s", flag, tag, st)) + } + } + tag = fields[0] + } + + if inline { + switch field.Type.Kind() { + case reflect.Map: + if inlineMap >= 0 { + return nil, errors.New("multiple ,inline maps in struct " + st.String()) + } + if field.Type.Key() != reflect.TypeOf("") { + return nil, errors.New("option ,inline needs a map with string keys in struct " + st.String()) + } + inlineMap = info.Num + case reflect.Struct, reflect.Ptr: + ftype := field.Type + for ftype.Kind() == reflect.Ptr { + ftype = ftype.Elem() + } + if ftype.Kind() != reflect.Struct { + return nil, errors.New("option ,inline may only be used on a struct or map field") + } + if reflect.PtrTo(ftype).Implements(unmarshalerType) { + inlineUnmarshalers = append(inlineUnmarshalers, []int{i}) + } else { + sinfo, err := 
getStructInfo(ftype) + if err != nil { + return nil, err + } + for _, index := range sinfo.InlineUnmarshalers { + inlineUnmarshalers = append(inlineUnmarshalers, append([]int{i}, index...)) + } + for _, finfo := range sinfo.FieldsList { + if _, found := fieldsMap[finfo.Key]; found { + msg := "duplicated key '" + finfo.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + if finfo.Inline == nil { + finfo.Inline = []int{i, finfo.Num} + } else { + finfo.Inline = append([]int{i}, finfo.Inline...) + } + finfo.Id = len(fieldsList) + fieldsMap[finfo.Key] = finfo + fieldsList = append(fieldsList, finfo) + } + } + default: + return nil, errors.New("option ,inline may only be used on a struct or map field") + } + continue + } + + if tag != "" { + info.Key = tag + } else { + info.Key = strings.ToLower(field.Name) + } + + if _, found = fieldsMap[info.Key]; found { + msg := "duplicated key '" + info.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + + info.Id = len(fieldsList) + fieldsList = append(fieldsList, info) + fieldsMap[info.Key] = info + } + + sinfo = &structInfo{ + FieldsMap: fieldsMap, + FieldsList: fieldsList, + InlineMap: inlineMap, + InlineUnmarshalers: inlineUnmarshalers, + } + + fieldMapMutex.Lock() + structMap[st] = sinfo + fieldMapMutex.Unlock() + return sinfo, nil +} + +// IsZeroer is used to check whether an object is zero to +// determine whether it should be omitted when marshaling +// with the omitempty flag. One notable implementation +// is time.Time. 
+type IsZeroer interface { + IsZero() bool +} + +func isZero(v reflect.Value) bool { + kind := v.Kind() + if z, ok := v.Interface().(IsZeroer); ok { + if (kind == reflect.Ptr || kind == reflect.Interface) && v.IsNil() { + return true + } + return z.IsZero() + } + switch kind { + case reflect.String: + return len(v.String()) == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + case reflect.Slice: + return v.Len() == 0 + case reflect.Map: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Struct: + vt := v.Type() + for i := v.NumField() - 1; i >= 0; i-- { + if vt.Field(i).PkgPath != "" { + continue // Private field + } + if !isZero(v.Field(i)) { + return false + } + } + return true + } + return false +} diff --git a/vendor/gopkg.in/yaml.v3/yamlh.go b/vendor/gopkg.in/yaml.v3/yamlh.go new file mode 100644 index 000000000..7c6d00770 --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/yamlh.go @@ -0,0 +1,807 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// Copyright (c) 2006-2010 Kirill Simonov +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package yaml + +import ( + "fmt" + "io" +) + +// The version directive data. +type yaml_version_directive_t struct { + major int8 // The major version number. + minor int8 // The minor version number. +} + +// The tag directive data. +type yaml_tag_directive_t struct { + handle []byte // The tag handle. + prefix []byte // The tag prefix. +} + +type yaml_encoding_t int + +// The stream encoding. +const ( + // Let the parser choose the encoding. + yaml_ANY_ENCODING yaml_encoding_t = iota + + yaml_UTF8_ENCODING // The default UTF-8 encoding. + yaml_UTF16LE_ENCODING // The UTF-16-LE encoding with BOM. + yaml_UTF16BE_ENCODING // The UTF-16-BE encoding with BOM. +) + +type yaml_break_t int + +// Line break types. +const ( + // Let the parser choose the break type. + yaml_ANY_BREAK yaml_break_t = iota + + yaml_CR_BREAK // Use CR for line breaks (Mac style). + yaml_LN_BREAK // Use LN for line breaks (Unix style). + yaml_CRLN_BREAK // Use CR LN for line breaks (DOS style). +) + +type yaml_error_type_t int + +// Many bad things could happen with the parser and emitter. +const ( + // No error is produced. + yaml_NO_ERROR yaml_error_type_t = iota + + yaml_MEMORY_ERROR // Cannot allocate or reallocate a block of memory. + yaml_READER_ERROR // Cannot read or decode the input stream. + yaml_SCANNER_ERROR // Cannot scan the input stream. + yaml_PARSER_ERROR // Cannot parse the input stream. + yaml_COMPOSER_ERROR // Cannot compose a YAML document. 
+ yaml_WRITER_ERROR // Cannot write to the output stream. + yaml_EMITTER_ERROR // Cannot emit a YAML stream. +) + +// The pointer position. +type yaml_mark_t struct { + index int // The position index. + line int // The position line. + column int // The position column. +} + +// Node Styles + +type yaml_style_t int8 + +type yaml_scalar_style_t yaml_style_t + +// Scalar styles. +const ( + // Let the emitter choose the style. + yaml_ANY_SCALAR_STYLE yaml_scalar_style_t = 0 + + yaml_PLAIN_SCALAR_STYLE yaml_scalar_style_t = 1 << iota // The plain scalar style. + yaml_SINGLE_QUOTED_SCALAR_STYLE // The single-quoted scalar style. + yaml_DOUBLE_QUOTED_SCALAR_STYLE // The double-quoted scalar style. + yaml_LITERAL_SCALAR_STYLE // The literal scalar style. + yaml_FOLDED_SCALAR_STYLE // The folded scalar style. +) + +type yaml_sequence_style_t yaml_style_t + +// Sequence styles. +const ( + // Let the emitter choose the style. + yaml_ANY_SEQUENCE_STYLE yaml_sequence_style_t = iota + + yaml_BLOCK_SEQUENCE_STYLE // The block sequence style. + yaml_FLOW_SEQUENCE_STYLE // The flow sequence style. +) + +type yaml_mapping_style_t yaml_style_t + +// Mapping styles. +const ( + // Let the emitter choose the style. + yaml_ANY_MAPPING_STYLE yaml_mapping_style_t = iota + + yaml_BLOCK_MAPPING_STYLE // The block mapping style. + yaml_FLOW_MAPPING_STYLE // The flow mapping style. +) + +// Tokens + +type yaml_token_type_t int + +// Token types. +const ( + // An empty token. + yaml_NO_TOKEN yaml_token_type_t = iota + + yaml_STREAM_START_TOKEN // A STREAM-START token. + yaml_STREAM_END_TOKEN // A STREAM-END token. + + yaml_VERSION_DIRECTIVE_TOKEN // A VERSION-DIRECTIVE token. + yaml_TAG_DIRECTIVE_TOKEN // A TAG-DIRECTIVE token. + yaml_DOCUMENT_START_TOKEN // A DOCUMENT-START token. + yaml_DOCUMENT_END_TOKEN // A DOCUMENT-END token. + + yaml_BLOCK_SEQUENCE_START_TOKEN // A BLOCK-SEQUENCE-START token. + yaml_BLOCK_MAPPING_START_TOKEN // A BLOCK-SEQUENCE-END token. 
+ yaml_BLOCK_END_TOKEN // A BLOCK-END token. + + yaml_FLOW_SEQUENCE_START_TOKEN // A FLOW-SEQUENCE-START token. + yaml_FLOW_SEQUENCE_END_TOKEN // A FLOW-SEQUENCE-END token. + yaml_FLOW_MAPPING_START_TOKEN // A FLOW-MAPPING-START token. + yaml_FLOW_MAPPING_END_TOKEN // A FLOW-MAPPING-END token. + + yaml_BLOCK_ENTRY_TOKEN // A BLOCK-ENTRY token. + yaml_FLOW_ENTRY_TOKEN // A FLOW-ENTRY token. + yaml_KEY_TOKEN // A KEY token. + yaml_VALUE_TOKEN // A VALUE token. + + yaml_ALIAS_TOKEN // An ALIAS token. + yaml_ANCHOR_TOKEN // An ANCHOR token. + yaml_TAG_TOKEN // A TAG token. + yaml_SCALAR_TOKEN // A SCALAR token. +) + +func (tt yaml_token_type_t) String() string { + switch tt { + case yaml_NO_TOKEN: + return "yaml_NO_TOKEN" + case yaml_STREAM_START_TOKEN: + return "yaml_STREAM_START_TOKEN" + case yaml_STREAM_END_TOKEN: + return "yaml_STREAM_END_TOKEN" + case yaml_VERSION_DIRECTIVE_TOKEN: + return "yaml_VERSION_DIRECTIVE_TOKEN" + case yaml_TAG_DIRECTIVE_TOKEN: + return "yaml_TAG_DIRECTIVE_TOKEN" + case yaml_DOCUMENT_START_TOKEN: + return "yaml_DOCUMENT_START_TOKEN" + case yaml_DOCUMENT_END_TOKEN: + return "yaml_DOCUMENT_END_TOKEN" + case yaml_BLOCK_SEQUENCE_START_TOKEN: + return "yaml_BLOCK_SEQUENCE_START_TOKEN" + case yaml_BLOCK_MAPPING_START_TOKEN: + return "yaml_BLOCK_MAPPING_START_TOKEN" + case yaml_BLOCK_END_TOKEN: + return "yaml_BLOCK_END_TOKEN" + case yaml_FLOW_SEQUENCE_START_TOKEN: + return "yaml_FLOW_SEQUENCE_START_TOKEN" + case yaml_FLOW_SEQUENCE_END_TOKEN: + return "yaml_FLOW_SEQUENCE_END_TOKEN" + case yaml_FLOW_MAPPING_START_TOKEN: + return "yaml_FLOW_MAPPING_START_TOKEN" + case yaml_FLOW_MAPPING_END_TOKEN: + return "yaml_FLOW_MAPPING_END_TOKEN" + case yaml_BLOCK_ENTRY_TOKEN: + return "yaml_BLOCK_ENTRY_TOKEN" + case yaml_FLOW_ENTRY_TOKEN: + return "yaml_FLOW_ENTRY_TOKEN" + case yaml_KEY_TOKEN: + return "yaml_KEY_TOKEN" + case yaml_VALUE_TOKEN: + return "yaml_VALUE_TOKEN" + case yaml_ALIAS_TOKEN: + return "yaml_ALIAS_TOKEN" + case yaml_ANCHOR_TOKEN: + return 
"yaml_ANCHOR_TOKEN" + case yaml_TAG_TOKEN: + return "yaml_TAG_TOKEN" + case yaml_SCALAR_TOKEN: + return "yaml_SCALAR_TOKEN" + } + return "" +} + +// The token structure. +type yaml_token_t struct { + // The token type. + typ yaml_token_type_t + + // The start/end of the token. + start_mark, end_mark yaml_mark_t + + // The stream encoding (for yaml_STREAM_START_TOKEN). + encoding yaml_encoding_t + + // The alias/anchor/scalar value or tag/tag directive handle + // (for yaml_ALIAS_TOKEN, yaml_ANCHOR_TOKEN, yaml_SCALAR_TOKEN, yaml_TAG_TOKEN, yaml_TAG_DIRECTIVE_TOKEN). + value []byte + + // The tag suffix (for yaml_TAG_TOKEN). + suffix []byte + + // The tag directive prefix (for yaml_TAG_DIRECTIVE_TOKEN). + prefix []byte + + // The scalar style (for yaml_SCALAR_TOKEN). + style yaml_scalar_style_t + + // The version directive major/minor (for yaml_VERSION_DIRECTIVE_TOKEN). + major, minor int8 +} + +// Events + +type yaml_event_type_t int8 + +// Event types. +const ( + // An empty event. + yaml_NO_EVENT yaml_event_type_t = iota + + yaml_STREAM_START_EVENT // A STREAM-START event. + yaml_STREAM_END_EVENT // A STREAM-END event. + yaml_DOCUMENT_START_EVENT // A DOCUMENT-START event. + yaml_DOCUMENT_END_EVENT // A DOCUMENT-END event. + yaml_ALIAS_EVENT // An ALIAS event. + yaml_SCALAR_EVENT // A SCALAR event. + yaml_SEQUENCE_START_EVENT // A SEQUENCE-START event. + yaml_SEQUENCE_END_EVENT // A SEQUENCE-END event. + yaml_MAPPING_START_EVENT // A MAPPING-START event. + yaml_MAPPING_END_EVENT // A MAPPING-END event. 
+ yaml_TAIL_COMMENT_EVENT +) + +var eventStrings = []string{ + yaml_NO_EVENT: "none", + yaml_STREAM_START_EVENT: "stream start", + yaml_STREAM_END_EVENT: "stream end", + yaml_DOCUMENT_START_EVENT: "document start", + yaml_DOCUMENT_END_EVENT: "document end", + yaml_ALIAS_EVENT: "alias", + yaml_SCALAR_EVENT: "scalar", + yaml_SEQUENCE_START_EVENT: "sequence start", + yaml_SEQUENCE_END_EVENT: "sequence end", + yaml_MAPPING_START_EVENT: "mapping start", + yaml_MAPPING_END_EVENT: "mapping end", + yaml_TAIL_COMMENT_EVENT: "tail comment", +} + +func (e yaml_event_type_t) String() string { + if e < 0 || int(e) >= len(eventStrings) { + return fmt.Sprintf("unknown event %d", e) + } + return eventStrings[e] +} + +// The event structure. +type yaml_event_t struct { + + // The event type. + typ yaml_event_type_t + + // The start and end of the event. + start_mark, end_mark yaml_mark_t + + // The document encoding (for yaml_STREAM_START_EVENT). + encoding yaml_encoding_t + + // The version directive (for yaml_DOCUMENT_START_EVENT). + version_directive *yaml_version_directive_t + + // The list of tag directives (for yaml_DOCUMENT_START_EVENT). + tag_directives []yaml_tag_directive_t + + // The comments + head_comment []byte + line_comment []byte + foot_comment []byte + tail_comment []byte + + // The anchor (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT, yaml_ALIAS_EVENT). + anchor []byte + + // The tag (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT). + tag []byte + + // The scalar value (for yaml_SCALAR_EVENT). + value []byte + + // Is the document start/end indicator implicit, or the tag optional? + // (for yaml_DOCUMENT_START_EVENT, yaml_DOCUMENT_END_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT, yaml_SCALAR_EVENT). + implicit bool + + // Is the tag optional for any non-plain style? (for yaml_SCALAR_EVENT). 
+ quoted_implicit bool + + // The style (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT). + style yaml_style_t +} + +func (e *yaml_event_t) scalar_style() yaml_scalar_style_t { return yaml_scalar_style_t(e.style) } +func (e *yaml_event_t) sequence_style() yaml_sequence_style_t { return yaml_sequence_style_t(e.style) } +func (e *yaml_event_t) mapping_style() yaml_mapping_style_t { return yaml_mapping_style_t(e.style) } + +// Nodes + +const ( + yaml_NULL_TAG = "tag:yaml.org,2002:null" // The tag !!null with the only possible value: null. + yaml_BOOL_TAG = "tag:yaml.org,2002:bool" // The tag !!bool with the values: true and false. + yaml_STR_TAG = "tag:yaml.org,2002:str" // The tag !!str for string values. + yaml_INT_TAG = "tag:yaml.org,2002:int" // The tag !!int for integer values. + yaml_FLOAT_TAG = "tag:yaml.org,2002:float" // The tag !!float for float values. + yaml_TIMESTAMP_TAG = "tag:yaml.org,2002:timestamp" // The tag !!timestamp for date and time values. + + yaml_SEQ_TAG = "tag:yaml.org,2002:seq" // The tag !!seq is used to denote sequences. + yaml_MAP_TAG = "tag:yaml.org,2002:map" // The tag !!map is used to denote mapping. + + // Not in original libyaml. + yaml_BINARY_TAG = "tag:yaml.org,2002:binary" + yaml_MERGE_TAG = "tag:yaml.org,2002:merge" + + yaml_DEFAULT_SCALAR_TAG = yaml_STR_TAG // The default scalar tag is !!str. + yaml_DEFAULT_SEQUENCE_TAG = yaml_SEQ_TAG // The default sequence tag is !!seq. + yaml_DEFAULT_MAPPING_TAG = yaml_MAP_TAG // The default mapping tag is !!map. +) + +type yaml_node_type_t int + +// Node types. +const ( + // An empty node. + yaml_NO_NODE yaml_node_type_t = iota + + yaml_SCALAR_NODE // A scalar node. + yaml_SEQUENCE_NODE // A sequence node. + yaml_MAPPING_NODE // A mapping node. +) + +// An element of a sequence node. +type yaml_node_item_t int + +// An element of a mapping node. +type yaml_node_pair_t struct { + key int // The key of the element. + value int // The value of the element. 
+} + +// The node structure. +type yaml_node_t struct { + typ yaml_node_type_t // The node type. + tag []byte // The node tag. + + // The node data. + + // The scalar parameters (for yaml_SCALAR_NODE). + scalar struct { + value []byte // The scalar value. + length int // The length of the scalar value. + style yaml_scalar_style_t // The scalar style. + } + + // The sequence parameters (for YAML_SEQUENCE_NODE). + sequence struct { + items_data []yaml_node_item_t // The stack of sequence items. + style yaml_sequence_style_t // The sequence style. + } + + // The mapping parameters (for yaml_MAPPING_NODE). + mapping struct { + pairs_data []yaml_node_pair_t // The stack of mapping pairs (key, value). + pairs_start *yaml_node_pair_t // The beginning of the stack. + pairs_end *yaml_node_pair_t // The end of the stack. + pairs_top *yaml_node_pair_t // The top of the stack. + style yaml_mapping_style_t // The mapping style. + } + + start_mark yaml_mark_t // The beginning of the node. + end_mark yaml_mark_t // The end of the node. + +} + +// The document structure. +type yaml_document_t struct { + + // The document nodes. + nodes []yaml_node_t + + // The version directive. + version_directive *yaml_version_directive_t + + // The list of tag directives. + tag_directives_data []yaml_tag_directive_t + tag_directives_start int // The beginning of the tag directives list. + tag_directives_end int // The end of the tag directives list. + + start_implicit int // Is the document start indicator implicit? + end_implicit int // Is the document end indicator implicit? + + // The start/end of the document. + start_mark, end_mark yaml_mark_t +} + +// The prototype of a read handler. +// +// The read handler is called when the parser needs to read more bytes from the +// source. The handler should write not more than size bytes to the buffer. +// The number of written bytes should be set to the size_read variable. 
+// +// [in,out] data A pointer to an application data specified by +// yaml_parser_set_input(). +// [out] buffer The buffer to write the data from the source. +// [in] size The size of the buffer. +// [out] size_read The actual number of bytes read from the source. +// +// On success, the handler should return 1. If the handler failed, +// the returned value should be 0. On EOF, the handler should set the +// size_read to 0 and return 1. +type yaml_read_handler_t func(parser *yaml_parser_t, buffer []byte) (n int, err error) + +// This structure holds information about a potential simple key. +type yaml_simple_key_t struct { + possible bool // Is a simple key possible? + required bool // Is a simple key required? + token_number int // The number of the token. + mark yaml_mark_t // The position mark. +} + +// The states of the parser. +type yaml_parser_state_t int + +const ( + yaml_PARSE_STREAM_START_STATE yaml_parser_state_t = iota + + yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE // Expect the beginning of an implicit document. + yaml_PARSE_DOCUMENT_START_STATE // Expect DOCUMENT-START. + yaml_PARSE_DOCUMENT_CONTENT_STATE // Expect the content of a document. + yaml_PARSE_DOCUMENT_END_STATE // Expect DOCUMENT-END. + yaml_PARSE_BLOCK_NODE_STATE // Expect a block node. + yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE // Expect a block node or indentless sequence. + yaml_PARSE_FLOW_NODE_STATE // Expect a flow node. + yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE // Expect the first entry of a block sequence. + yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE // Expect an entry of a block sequence. + yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE // Expect an entry of an indentless sequence. + yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE // Expect the first key of a block mapping. + yaml_PARSE_BLOCK_MAPPING_KEY_STATE // Expect a block mapping key. + yaml_PARSE_BLOCK_MAPPING_VALUE_STATE // Expect a block mapping value. 
+ yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE // Expect the first entry of a flow sequence. + yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE // Expect an entry of a flow sequence. + yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE // Expect a key of an ordered mapping. + yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE // Expect a value of an ordered mapping. + yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE // Expect the and of an ordered mapping entry. + yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE // Expect the first key of a flow mapping. + yaml_PARSE_FLOW_MAPPING_KEY_STATE // Expect a key of a flow mapping. + yaml_PARSE_FLOW_MAPPING_VALUE_STATE // Expect a value of a flow mapping. + yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE // Expect an empty value of a flow mapping. + yaml_PARSE_END_STATE // Expect nothing. +) + +func (ps yaml_parser_state_t) String() string { + switch ps { + case yaml_PARSE_STREAM_START_STATE: + return "yaml_PARSE_STREAM_START_STATE" + case yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE: + return "yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE" + case yaml_PARSE_DOCUMENT_START_STATE: + return "yaml_PARSE_DOCUMENT_START_STATE" + case yaml_PARSE_DOCUMENT_CONTENT_STATE: + return "yaml_PARSE_DOCUMENT_CONTENT_STATE" + case yaml_PARSE_DOCUMENT_END_STATE: + return "yaml_PARSE_DOCUMENT_END_STATE" + case yaml_PARSE_BLOCK_NODE_STATE: + return "yaml_PARSE_BLOCK_NODE_STATE" + case yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE: + return "yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE" + case yaml_PARSE_FLOW_NODE_STATE: + return "yaml_PARSE_FLOW_NODE_STATE" + case yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE: + return "yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE" + case yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE: + return "yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE" + case yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE: + return "yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE" + case yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE: + return "yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE" + case 
yaml_PARSE_BLOCK_MAPPING_KEY_STATE: + return "yaml_PARSE_BLOCK_MAPPING_KEY_STATE" + case yaml_PARSE_BLOCK_MAPPING_VALUE_STATE: + return "yaml_PARSE_BLOCK_MAPPING_VALUE_STATE" + case yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE: + return "yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE" + case yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE: + return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE" + case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE: + return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE" + case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE: + return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE" + case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE: + return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE" + case yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE: + return "yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE" + case yaml_PARSE_FLOW_MAPPING_KEY_STATE: + return "yaml_PARSE_FLOW_MAPPING_KEY_STATE" + case yaml_PARSE_FLOW_MAPPING_VALUE_STATE: + return "yaml_PARSE_FLOW_MAPPING_VALUE_STATE" + case yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE: + return "yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE" + case yaml_PARSE_END_STATE: + return "yaml_PARSE_END_STATE" + } + return "" +} + +// This structure holds aliases data. +type yaml_alias_data_t struct { + anchor []byte // The anchor. + index int // The node id. + mark yaml_mark_t // The anchor mark. +} + +// The parser structure. +// +// All members are internal. Manage the structure using the +// yaml_parser_ family of functions. +type yaml_parser_t struct { + + // Error handling + + error yaml_error_type_t // Error type. + + problem string // Error description. + + // The byte about which the problem occurred. + problem_offset int + problem_value int + problem_mark yaml_mark_t + + // The error context. + context string + context_mark yaml_mark_t + + // Reader stuff + + read_handler yaml_read_handler_t // Read handler. + + input_reader io.Reader // File input data. + input []byte // String input data. 
+ input_pos int + + eof bool // EOF flag + + buffer []byte // The working buffer. + buffer_pos int // The current position of the buffer. + + unread int // The number of unread characters in the buffer. + + newlines int // The number of line breaks since last non-break/non-blank character + + raw_buffer []byte // The raw buffer. + raw_buffer_pos int // The current position of the buffer. + + encoding yaml_encoding_t // The input encoding. + + offset int // The offset of the current position (in bytes). + mark yaml_mark_t // The mark of the current position. + + // Comments + + head_comment []byte // The current head comments + line_comment []byte // The current line comments + foot_comment []byte // The current foot comments + tail_comment []byte // Foot comment that happens at the end of a block. + stem_comment []byte // Comment in item preceding a nested structure (list inside list item, etc) + + comments []yaml_comment_t // The folded comments for all parsed tokens + comments_head int + + // Scanner stuff + + stream_start_produced bool // Have we started to scan the input stream? + stream_end_produced bool // Have we reached the end of the input stream? + + flow_level int // The number of unclosed '[' and '{' indicators. + + tokens []yaml_token_t // The tokens queue. + tokens_head int // The head of the tokens queue. + tokens_parsed int // The number of tokens fetched from the queue. + token_available bool // Does the tokens queue contain a token ready for dequeueing. + + indent int // The current indentation level. + indents []int // The indentation levels stack. + + simple_key_allowed bool // May a simple key occur at the current position? + simple_keys []yaml_simple_key_t // The stack of simple keys. + simple_keys_by_tok map[int]int // possible simple_key indexes indexed by token_number + + // Parser stuff + + state yaml_parser_state_t // The current parser state. + states []yaml_parser_state_t // The parser states stack. 
+ marks []yaml_mark_t // The stack of marks. + tag_directives []yaml_tag_directive_t // The list of TAG directives. + + // Dumper stuff + + aliases []yaml_alias_data_t // The alias data. + + document *yaml_document_t // The currently parsed document. +} + +type yaml_comment_t struct { + + scan_mark yaml_mark_t // Position where scanning for comments started + token_mark yaml_mark_t // Position after which tokens will be associated with this comment + start_mark yaml_mark_t // Position of '#' comment mark + end_mark yaml_mark_t // Position where comment terminated + + head []byte + line []byte + foot []byte +} + +// Emitter Definitions + +// The prototype of a write handler. +// +// The write handler is called when the emitter needs to flush the accumulated +// characters to the output. The handler should write @a size bytes of the +// @a buffer to the output. +// +// @param[in,out] data A pointer to an application data specified by +// yaml_emitter_set_output(). +// @param[in] buffer The buffer with bytes to be written. +// @param[in] size The size of the buffer. +// +// @returns On success, the handler should return @c 1. If the handler failed, +// the returned value should be @c 0. +// +type yaml_write_handler_t func(emitter *yaml_emitter_t, buffer []byte) error + +type yaml_emitter_state_t int + +// The emitter states. +const ( + // Expect STREAM-START. + yaml_EMIT_STREAM_START_STATE yaml_emitter_state_t = iota + + yaml_EMIT_FIRST_DOCUMENT_START_STATE // Expect the first DOCUMENT-START or STREAM-END. + yaml_EMIT_DOCUMENT_START_STATE // Expect DOCUMENT-START or STREAM-END. + yaml_EMIT_DOCUMENT_CONTENT_STATE // Expect the content of a document. + yaml_EMIT_DOCUMENT_END_STATE // Expect DOCUMENT-END. + yaml_EMIT_FLOW_SEQUENCE_FIRST_ITEM_STATE // Expect the first item of a flow sequence. 
+ yaml_EMIT_FLOW_SEQUENCE_TRAIL_ITEM_STATE // Expect the next item of a flow sequence, with the comma already written out + yaml_EMIT_FLOW_SEQUENCE_ITEM_STATE // Expect an item of a flow sequence. + yaml_EMIT_FLOW_MAPPING_FIRST_KEY_STATE // Expect the first key of a flow mapping. + yaml_EMIT_FLOW_MAPPING_TRAIL_KEY_STATE // Expect the next key of a flow mapping, with the comma already written out + yaml_EMIT_FLOW_MAPPING_KEY_STATE // Expect a key of a flow mapping. + yaml_EMIT_FLOW_MAPPING_SIMPLE_VALUE_STATE // Expect a value for a simple key of a flow mapping. + yaml_EMIT_FLOW_MAPPING_VALUE_STATE // Expect a value of a flow mapping. + yaml_EMIT_BLOCK_SEQUENCE_FIRST_ITEM_STATE // Expect the first item of a block sequence. + yaml_EMIT_BLOCK_SEQUENCE_ITEM_STATE // Expect an item of a block sequence. + yaml_EMIT_BLOCK_MAPPING_FIRST_KEY_STATE // Expect the first key of a block mapping. + yaml_EMIT_BLOCK_MAPPING_KEY_STATE // Expect the key of a block mapping. + yaml_EMIT_BLOCK_MAPPING_SIMPLE_VALUE_STATE // Expect a value for a simple key of a block mapping. + yaml_EMIT_BLOCK_MAPPING_VALUE_STATE // Expect a value of a block mapping. + yaml_EMIT_END_STATE // Expect nothing. +) + +// The emitter structure. +// +// All members are internal. Manage the structure using the @c yaml_emitter_ +// family of functions. +type yaml_emitter_t struct { + + // Error handling + + error yaml_error_type_t // Error type. + problem string // Error description. + + // Writer stuff + + write_handler yaml_write_handler_t // Write handler. + + output_buffer *[]byte // String output data. + output_writer io.Writer // File output data. + + buffer []byte // The working buffer. + buffer_pos int // The current position of the buffer. + + raw_buffer []byte // The raw buffer. + raw_buffer_pos int // The current position of the buffer. + + encoding yaml_encoding_t // The stream encoding. + + // Emitter stuff + + canonical bool // If the output is in the canonical style? 
+ best_indent int // The number of indentation spaces. + best_width int // The preferred width of the output lines. + unicode bool // Allow unescaped non-ASCII characters? + line_break yaml_break_t // The preferred line break. + + state yaml_emitter_state_t // The current emitter state. + states []yaml_emitter_state_t // The stack of states. + + events []yaml_event_t // The event queue. + events_head int // The head of the event queue. + + indents []int // The stack of indentation levels. + + tag_directives []yaml_tag_directive_t // The list of tag directives. + + indent int // The current indentation level. + + flow_level int // The current flow level. + + root_context bool // Is it the document root context? + sequence_context bool // Is it a sequence context? + mapping_context bool // Is it a mapping context? + simple_key_context bool // Is it a simple mapping key context? + + line int // The current line. + column int // The current column. + whitespace bool // If the last character was a whitespace? + indention bool // If the last character was an indentation character (' ', '-', '?', ':')? + open_ended bool // If an explicit document end is required? + + space_above bool // Is there's an empty line above? + foot_indent int // The indent used to write the foot comment above, or -1 if none. + + // Anchor analysis. + anchor_data struct { + anchor []byte // The anchor value. + alias bool // Is it an alias? + } + + // Tag analysis. + tag_data struct { + handle []byte // The tag handle. + suffix []byte // The tag suffix. + } + + // Scalar analysis. + scalar_data struct { + value []byte // The scalar value. + multiline bool // Does the scalar contain line breaks? + flow_plain_allowed bool // Can the scalar be expessed in the flow plain style? + block_plain_allowed bool // Can the scalar be expressed in the block plain style? + single_quoted_allowed bool // Can the scalar be expressed in the single quoted style? 
+ block_allowed bool // Can the scalar be expressed in the literal or folded styles? + style yaml_scalar_style_t // The output style. + } + + // Comments + head_comment []byte + line_comment []byte + foot_comment []byte + tail_comment []byte + + key_line_comment []byte + + // Dumper stuff + + opened bool // If the stream was already opened? + closed bool // If the stream was already closed? + + // The information associated with the document nodes. + anchors *struct { + references int // The number of references. + anchor int // The anchor id. + serialized bool // If the node has been emitted? + } + + last_anchor_id int // The last assigned anchor id. + + document *yaml_document_t // The currently emitted document. +} diff --git a/vendor/gopkg.in/yaml.v3/yamlprivateh.go b/vendor/gopkg.in/yaml.v3/yamlprivateh.go new file mode 100644 index 000000000..e88f9c54a --- /dev/null +++ b/vendor/gopkg.in/yaml.v3/yamlprivateh.go @@ -0,0 +1,198 @@ +// +// Copyright (c) 2011-2019 Canonical Ltd +// Copyright (c) 2006-2010 Kirill Simonov +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +// of the Software, and to permit persons to whom the Software is furnished to do +// so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package yaml + +const ( + // The size of the input raw buffer. + input_raw_buffer_size = 512 + + // The size of the input buffer. + // It should be possible to decode the whole raw buffer. + input_buffer_size = input_raw_buffer_size * 3 + + // The size of the output buffer. + output_buffer_size = 128 + + // The size of the output raw buffer. + // It should be possible to encode the whole output buffer. + output_raw_buffer_size = (output_buffer_size*2 + 2) + + // The size of other stacks and queues. + initial_stack_size = 16 + initial_queue_size = 16 + initial_string_size = 16 +) + +// Check if the character at the specified position is an alphabetical +// character, a digit, '_', or '-'. +func is_alpha(b []byte, i int) bool { + return b[i] >= '0' && b[i] <= '9' || b[i] >= 'A' && b[i] <= 'Z' || b[i] >= 'a' && b[i] <= 'z' || b[i] == '_' || b[i] == '-' +} + +// Check if the character at the specified position is a digit. +func is_digit(b []byte, i int) bool { + return b[i] >= '0' && b[i] <= '9' +} + +// Get the value of a digit. +func as_digit(b []byte, i int) int { + return int(b[i]) - '0' +} + +// Check if the character at the specified position is a hex-digit. +func is_hex(b []byte, i int) bool { + return b[i] >= '0' && b[i] <= '9' || b[i] >= 'A' && b[i] <= 'F' || b[i] >= 'a' && b[i] <= 'f' +} + +// Get the value of a hex-digit. +func as_hex(b []byte, i int) int { + bi := b[i] + if bi >= 'A' && bi <= 'F' { + return int(bi) - 'A' + 10 + } + if bi >= 'a' && bi <= 'f' { + return int(bi) - 'a' + 10 + } + return int(bi) - '0' +} + +// Check if the character is ASCII. 
+func is_ascii(b []byte, i int) bool { + return b[i] <= 0x7F +} + +// Check if the character at the start of the buffer can be printed unescaped. +func is_printable(b []byte, i int) bool { + return ((b[i] == 0x0A) || // . == #x0A + (b[i] >= 0x20 && b[i] <= 0x7E) || // #x20 <= . <= #x7E + (b[i] == 0xC2 && b[i+1] >= 0xA0) || // #0xA0 <= . <= #xD7FF + (b[i] > 0xC2 && b[i] < 0xED) || + (b[i] == 0xED && b[i+1] < 0xA0) || + (b[i] == 0xEE) || + (b[i] == 0xEF && // #xE000 <= . <= #xFFFD + !(b[i+1] == 0xBB && b[i+2] == 0xBF) && // && . != #xFEFF + !(b[i+1] == 0xBF && (b[i+2] == 0xBE || b[i+2] == 0xBF)))) +} + +// Check if the character at the specified position is NUL. +func is_z(b []byte, i int) bool { + return b[i] == 0x00 +} + +// Check if the beginning of the buffer is a BOM. +func is_bom(b []byte, i int) bool { + return b[0] == 0xEF && b[1] == 0xBB && b[2] == 0xBF +} + +// Check if the character at the specified position is space. +func is_space(b []byte, i int) bool { + return b[i] == ' ' +} + +// Check if the character at the specified position is tab. +func is_tab(b []byte, i int) bool { + return b[i] == '\t' +} + +// Check if the character at the specified position is blank (space or tab). +func is_blank(b []byte, i int) bool { + //return is_space(b, i) || is_tab(b, i) + return b[i] == ' ' || b[i] == '\t' +} + +// Check if the character at the specified position is a line break. +func is_break(b []byte, i int) bool { + return (b[i] == '\r' || // CR (#xD) + b[i] == '\n' || // LF (#xA) + b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85) + b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028) + b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9) // PS (#x2029) +} + +func is_crlf(b []byte, i int) bool { + return b[i] == '\r' && b[i+1] == '\n' +} + +// Check if the character is a line break or NUL. 
+func is_breakz(b []byte, i int) bool { + //return is_break(b, i) || is_z(b, i) + return ( + // is_break: + b[i] == '\r' || // CR (#xD) + b[i] == '\n' || // LF (#xA) + b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85) + b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028) + b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029) + // is_z: + b[i] == 0) +} + +// Check if the character is a line break, space, or NUL. +func is_spacez(b []byte, i int) bool { + //return is_space(b, i) || is_breakz(b, i) + return ( + // is_space: + b[i] == ' ' || + // is_breakz: + b[i] == '\r' || // CR (#xD) + b[i] == '\n' || // LF (#xA) + b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85) + b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028) + b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029) + b[i] == 0) +} + +// Check if the character is a line break, space, tab, or NUL. +func is_blankz(b []byte, i int) bool { + //return is_blank(b, i) || is_breakz(b, i) + return ( + // is_blank: + b[i] == ' ' || b[i] == '\t' || + // is_breakz: + b[i] == '\r' || // CR (#xD) + b[i] == '\n' || // LF (#xA) + b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85) + b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028) + b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029) + b[i] == 0) +} + +// Determine the width of the character. +func width(b byte) int { + // Don't replace these by a switch without first + // confirming that it is being inlined. 
+ if b&0x80 == 0x00 { + return 1 + } + if b&0xE0 == 0xC0 { + return 2 + } + if b&0xF0 == 0xE0 { + return 3 + } + if b&0xF8 == 0xF0 { + return 4 + } + return 0 + +} diff --git a/vendor/modules.txt b/vendor/modules.txt index f1e3ce32c..b7fb68df9 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -72,3 +72,6 @@ google.golang.org/protobuf/types/known/emptypb google.golang.org/protobuf/types/known/structpb google.golang.org/protobuf/types/known/timestamppb google.golang.org/protobuf/types/known/wrapperspb +# gopkg.in/yaml.v3 v3.0.1 +## explicit +gopkg.in/yaml.v3 From 4adcf4d254654c43d2648acdc7335dc714c4002f Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 13 Feb 2025 16:39:09 -0800 Subject: [PATCH 11/46] Default enable DefaultUTCTimeZone (#1130) --- cel/cel_test.go | 186 +++++++++++++++++++++------- cel/env.go | 34 ++--- cel/library.go | 200 ++++-------------------------- common/stdlib/standard.go | 212 ++++++++++++++++++++++++++++---- interpreter/interpreter_test.go | 40 ++++++ policy/compiler_test.go | 1 - 6 files changed, 404 insertions(+), 269 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 77f35d549..4f873bbf7 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -2146,47 +2146,150 @@ func TestRegexOptimizer(t *testing.T) { } } -func TestDefaultUTCTimeZone(t *testing.T) { - env := testEnv(t, Variable("x", TimestampType), DefaultUTCTimeZone(true)) - out, err := interpret(t, env, ` - x.getFullYear() == 1970 - && x.getMonth() == 0 - && x.getDayOfYear() == 0 - && x.getDayOfMonth() == 0 - && x.getDate() == 1 - && x.getDayOfWeek() == 4 - && x.getHours() == 2 - && x.getMinutes() == 5 - && x.getSeconds() == 6 - && x.getMilliseconds() == 1 - && x.getFullYear('-07:30') == 1969 - && x.getDayOfYear('-07:30') == 364 - && x.getMonth('-07:30') == 11 - && x.getDayOfMonth('-07:30') == 30 - && x.getDate('-07:30') == 31 - && x.getDayOfWeek('-07:30') == 3 - && x.getHours('-07:30') == 18 - && x.getMinutes('-07:30') == 35 - && x.getSeconds('-07:30') 
== 6 - && x.getMilliseconds('-07:30') == 1 - && x.getFullYear('23:15') == 1970 - && x.getDayOfYear('23:15') == 1 - && x.getMonth('23:15') == 0 - && x.getDayOfMonth('23:15') == 1 - && x.getDate('23:15') == 2 - && x.getDayOfWeek('23:15') == 5 - && x.getHours('23:15') == 1 - && x.getMinutes('23:15') == 20 - && x.getSeconds('23:15') == 6 - && x.getMilliseconds('23:15') == 1`, - map[string]any{ - "x": time.Unix(7506, 1000000).Local(), - }) - if err != nil { - t.Fatalf("prg.Eval() failed: %v", err) +func TestDefaultUTCTimeZoneDisabled(t *testing.T) { + testEnvs := []struct { + name string + env *Env + }{ + {"default", testEnv(t, Variable("x", TimestampType))}, + {"enabled", testEnv(t, Variable("x", TimestampType), DefaultUTCTimeZone(true))}, + {"disabled", testEnv(t, Variable("x", TimestampType), DefaultUTCTimeZone(false))}, + } + exprs := []struct { + name string + value string + envOut map[string]ref.Val + }{ + { + name: "default-timezone", + value: ` + x.getFullYear() == 1970 + && x.getMonth() == 0 + && x.getDayOfYear() == 0 + && x.getDayOfMonth() == 0 + && x.getDate() == 1 + && x.getDayOfWeek() == 4 + && x.getHours() == 2 + && x.getMinutes() == 5 + && x.getSeconds() == 6 + && x.getMilliseconds() == 1`, + envOut: map[string]ref.Val{ + "default": types.True, + "enabled": types.True, + "disabled": types.False, + }, + }, + { + name: "default-local-year", + value: `x.getFullYear()`, + envOut: map[string]ref.Val{ + "default": types.Int(1970), + "enabled": types.Int(1970), + "disabled": types.Int(1969), + }, + }, + { + name: "default-local-day-of-year", + value: `x.getDayOfYear()`, + envOut: map[string]ref.Val{ + "default": types.Int(0), + "enabled": types.Int(0), + "disabled": types.Int(364), + }, + }, + { + name: "default-local-month", + value: `x.getMonth()`, + envOut: map[string]ref.Val{ + "default": types.Int(0), + "enabled": types.Int(0), + "disabled": types.Int(11), + }, + }, + { + name: "default-local-day-of-month", + value: ` + x.getDayOfMonth() == 30 + && 
x.getDate() == 31`, + envOut: map[string]ref.Val{ + "default": types.False, + "enabled": types.False, + "disabled": types.True, + }, + }, + { + name: "default-local-dates", + value: `x.getDayOfWeek()`, + envOut: map[string]ref.Val{ + "default": types.Int(4), + "enabled": types.Int(4), + "disabled": types.Int(3), + }, + }, + { + name: "default-local-times", + value: ` + x.getHours() == 18 + && x.getMinutes() == 5 + && x.getSeconds() == 6 + && x.getMilliseconds() == 1`, + envOut: map[string]ref.Val{ + "default": types.False, + "enabled": types.False, + "disabled": types.True, + }, + }, + { + name: "explicit", + value: ` + x.getFullYear('-07:30') == 1969 + && x.getDayOfYear('-07:30') == 364 + && x.getMonth('-07:30') == 11 + && x.getDayOfMonth('-07:30') == 30 + && x.getDate('-07:30') == 31 + && x.getDayOfWeek('-07:30') == 3 + && x.getHours('-07:30') == 18 + && x.getMinutes('-07:30') == 35 + && x.getSeconds('-07:30') == 6 + && x.getMilliseconds('-07:30') == 1 + && x.getFullYear('23:15') == 1970 + && x.getDayOfYear('23:15') == 1 + && x.getMonth('23:15') == 0 + && x.getDayOfMonth('23:15') == 1 + && x.getDate('23:15') == 2 + && x.getDayOfWeek('23:15') == 5 + && x.getHours('23:15') == 1 + && x.getMinutes('23:15') == 20 + && x.getSeconds('23:15') == 6 + && x.getMilliseconds('23:15') == 1`, + envOut: map[string]ref.Val{ + "default": types.True, + "enabled": types.True, + "disabled": types.True, + }, + }, } - if out != types.True { - t.Errorf("Eval() got %v, wanted true", out) + + offset, _ := time.ParseDuration("-8h") + vars := map[string]any{ + "x": time.Unix(7506, 1000000).In(time.FixedZone("", int(offset.Seconds()))), + } + for _, e := range testEnvs { + te := e + for _, expr := range exprs { + ex := expr + t.Run(fmt.Sprintf("%s/%s", te.name, ex.name), func(t *testing.T) { + env := te.env + expr := ex.value + out, err := interpret(t, env, expr, vars) + if err != nil { + t.Fatal(err) + } + if out.Equal(ex.envOut[te.name]) != types.True { + t.Errorf("interpret got %v, wanted 
%v", out, ex.envOut[te.name]) + } + }) + } } } @@ -2194,7 +2297,6 @@ func TestDefaultUTCTimeZoneExtension(t *testing.T) { env := testEnv(t, Variable("x", TimestampType), Variable("y", DurationType), - DefaultUTCTimeZone(true), ) env, err := env.Extend() if err != nil { @@ -2220,7 +2322,7 @@ func TestDefaultUTCTimeZoneExtension(t *testing.T) { } func TestDefaultUTCTimeZoneError(t *testing.T) { - env := testEnv(t, Variable("x", TimestampType), DefaultUTCTimeZone(true)) + env := testEnv(t, Variable("x", TimestampType)) out, err := interpret(t, env, ` x.getFullYear(':xx') == 1969 || x.getDayOfYear('xx:') == 364 diff --git a/cel/env.go b/cel/env.go index 16531bb02..5bebfbc30 100644 --- a/cel/env.go +++ b/cel/env.go @@ -718,10 +718,15 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) { } } - // If the default UTC timezone fix has been enabled, make sure the library is configured - e, err = e.maybeApplyFeature(featureDefaultUTCTimeZone, Lib(timeUTCLibrary{})) - if err != nil { - return nil, err + // If the default UTC timezone has been disabled, configure the legacy overloads + if utcTime, isSet := e.features[featureDefaultUTCTimeZone]; isSet && !utcTime { + if !e.appliedFeatures[featureDefaultUTCTimeZone] { + e.appliedFeatures[featureDefaultUTCTimeZone] = true + e, err = Lib(timeLegacyLibrary{})(e) + if err != nil { + return nil, err + } + } } // Configure the parser. @@ -805,27 +810,6 @@ func (e *Env) getCheckerOrError() (*checker.Env, error) { return e.chk, e.chkErr } -// maybeApplyFeature determines whether the feature-guarded option is enabled, and if so applies -// the feature if it has not already been enabled. 
-func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) { - if !e.HasFeature(feature) { - return e, nil - } - _, applied := e.appliedFeatures[feature] - if applied { - return e, nil - } - e, err := option(e) - if err != nil { - return nil, err - } - // record that the feature has been applied since it will generate declarations - // and functions which will be propagated on Extend() calls and which should only - // be registered once. - e.appliedFeatures[feature] = true - return e, nil -} - // computeUnknownVars determines a set of missing variables based on the input activation and the // environment's configured declaration set. func (e *Env) computeUnknownVars(vars Activation) []*interpreter.AttributePattern { diff --git a/cel/library.go b/cel/library.go index ebe05dc93..c2fcafc11 100644 --- a/cel/library.go +++ b/cel/library.go @@ -17,9 +17,6 @@ package cel import ( "fmt" "math" - "strconv" - "strings" - "time" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/decls" @@ -720,250 +717,99 @@ func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val { return opt.rhs.Eval(ctx) } -type timeUTCLibrary struct{} +type timeLegacyLibrary struct{} -func (timeUTCLibrary) CompileOptions() []EnvOption { +func (timeLegacyLibrary) CompileOptions() []EnvOption { return timeOverloadDeclarations } -func (timeUTCLibrary) ProgramOptions() []ProgramOption { +func (timeLegacyLibrary) ProgramOptions() []ProgramOption { return []ProgramOption{} } // Declarations and functions which enable using UTC on time.Time inputs when the timezone is unspecified // in the CEL expression. 
var ( - utcTZ = types.String("UTC") - timeOverloadDeclarations = []EnvOption{ - Function(overloads.TimeGetHours, - MemberOverload(overloads.DurationToHours, []*Type{DurationType}, IntType, - UnaryBinding(types.DurationGetHours))), - Function(overloads.TimeGetMinutes, - MemberOverload(overloads.DurationToMinutes, []*Type{DurationType}, IntType, - UnaryBinding(types.DurationGetMinutes))), - Function(overloads.TimeGetSeconds, - MemberOverload(overloads.DurationToSeconds, []*Type{DurationType}, IntType, - UnaryBinding(types.DurationGetSeconds))), - Function(overloads.TimeGetMilliseconds, - MemberOverload(overloads.DurationToMilliseconds, []*Type{DurationType}, IntType, - UnaryBinding(types.DurationGetMilliseconds))), Function(overloads.TimeGetFullYear, MemberOverload(overloads.TimestampToYear, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetFullYear(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetFullYear, overloads.TimestampToYear, []ref.Val{}) }), ), - MemberOverload(overloads.TimestampToYearWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetFullYear), - ), ), Function(overloads.TimeGetMonth, MemberOverload(overloads.TimestampToMonth, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetMonth(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetMonth, overloads.TimestampToMonth, []ref.Val{}) }), ), - MemberOverload(overloads.TimestampToMonthWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetMonth), - ), ), Function(overloads.TimeGetDayOfYear, MemberOverload(overloads.TimestampToDayOfYear, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetDayOfYear(ts, utcTZ) - }), - ), - MemberOverload(overloads.TimestampToDayOfYearWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(func(ts, tz ref.Val) ref.Val { - return 
timestampGetDayOfYear(ts, tz) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetDayOfYear, overloads.TimestampToDayOfYear, []ref.Val{}) }), ), ), Function(overloads.TimeGetDayOfMonth, MemberOverload(overloads.TimestampToDayOfMonthZeroBased, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetDayOfMonthZeroBased(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetDayOfMonth, overloads.TimestampToDayOfMonthZeroBased, []ref.Val{}) }), ), - MemberOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetDayOfMonthZeroBased), - ), ), Function(overloads.TimeGetDate, MemberOverload(overloads.TimestampToDayOfMonthOneBased, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetDayOfMonthOneBased(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetDate, overloads.TimestampToDayOfMonthOneBased, []ref.Val{}) }), ), - MemberOverload(overloads.TimestampToDayOfMonthOneBasedWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetDayOfMonthOneBased), - ), ), Function(overloads.TimeGetDayOfWeek, MemberOverload(overloads.TimestampToDayOfWeek, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetDayOfWeek(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetDayOfWeek, overloads.TimestampToDayOfWeek, []ref.Val{}) }), ), - MemberOverload(overloads.TimestampToDayOfWeekWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetDayOfWeek), - ), ), Function(overloads.TimeGetHours, MemberOverload(overloads.TimestampToHours, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetHours(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetHours, overloads.TimestampToHours, []ref.Val{}) }), ), - 
MemberOverload(overloads.TimestampToHoursWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetHours), - ), ), Function(overloads.TimeGetMinutes, MemberOverload(overloads.TimestampToMinutes, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetMinutes(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetMinutes, overloads.TimestampToMinutes, []ref.Val{}) }), ), - MemberOverload(overloads.TimestampToMinutesWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetMinutes), - ), ), Function(overloads.TimeGetSeconds, MemberOverload(overloads.TimestampToSeconds, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetSeconds(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetSeconds, overloads.TimestampToSeconds, []ref.Val{}) }), ), - MemberOverload(overloads.TimestampToSecondsWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetSeconds), - ), ), Function(overloads.TimeGetMilliseconds, MemberOverload(overloads.TimestampToMilliseconds, []*Type{TimestampType}, IntType, UnaryBinding(func(ts ref.Val) ref.Val { - return timestampGetMilliseconds(ts, utcTZ) + t := ts.(types.Timestamp) + return t.Receive(overloads.TimeGetMilliseconds, overloads.TimestampToMilliseconds, []ref.Val{}) }), ), - MemberOverload(overloads.TimestampToMillisecondsWithTz, []*Type{TimestampType, StringType}, IntType, - BinaryBinding(timestampGetMilliseconds), - ), ), } ) - -func timestampGetFullYear(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.Year()) -} - -func timestampGetMonth(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - // CEL spec indicates that the month should be 0-based, but the Time value - // for Month() is 1-based. 
- return types.Int(t.Month() - 1) -} - -func timestampGetDayOfYear(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.YearDay() - 1) -} - -func timestampGetDayOfMonthZeroBased(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.Day() - 1) -} - -func timestampGetDayOfMonthOneBased(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.Day()) -} - -func timestampGetDayOfWeek(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.Weekday()) -} - -func timestampGetHours(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.Hour()) -} - -func timestampGetMinutes(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.Minute()) -} - -func timestampGetSeconds(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.Second()) -} - -func timestampGetMilliseconds(ts, tz ref.Val) ref.Val { - t, err := inTimeZone(ts, tz) - if err != nil { - return types.NewErrFromString(err.Error()) - } - return types.Int(t.Nanosecond() / 1000000) -} - -func inTimeZone(ts, tz ref.Val) (time.Time, error) { - t := ts.(types.Timestamp) - val := string(tz.(types.String)) - ind := strings.Index(val, ":") - if ind == -1 { - loc, err := time.LoadLocation(val) - if err != nil { - return time.Time{}, err - } - return t.In(loc), nil - } - - // If the input is not the name of a timezone (for example, 'US/Central'), it should be a numerical offset from UTC - // in the format 
^(+|-)(0[0-9]|1[0-4]):[0-5][0-9]$. The numerical input is parsed in terms of hours and minutes. - hr, err := strconv.Atoi(string(val[0:ind])) - if err != nil { - return time.Time{}, err - } - min, err := strconv.Atoi(string(val[ind+1:])) - if err != nil { - return time.Time{}, err - } - var offset int - if string(val[0]) == "-" { - offset = hr*60 - min - } else { - offset = hr*60 + min - } - secondsEastOfUTC := int((time.Duration(offset) * time.Minute).Seconds()) - timezone := time.FixedZone("", secondsEastOfUTC) - return t.In(timezone), nil -} diff --git a/common/stdlib/standard.go b/common/stdlib/standard.go index 1550c1786..cbaa7d072 100644 --- a/common/stdlib/standard.go +++ b/common/stdlib/standard.go @@ -16,6 +16,10 @@ package stdlib import ( + "strconv" + "strings" + "time" + "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/operators" @@ -28,6 +32,7 @@ import ( var ( stdFunctions []*decls.FunctionDecl stdTypes []*decls.VariableDecl + utcTZ = types.String("UTC") ) func init() { @@ -497,71 +502,115 @@ func init() { // Timestamp / duration functions function(overloads.TimeGetFullYear, decls.MemberOverload(overloads.TimestampToYear, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetFullYear(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToYearWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType)), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetFullYear))), function(overloads.TimeGetMonth, decls.MemberOverload(overloads.TimestampToMonth, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetMonth(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToMonthWithTz, - 
argTypes(types.TimestampType, types.StringType), types.IntType)), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetMonth))), function(overloads.TimeGetDayOfYear, decls.MemberOverload(overloads.TimestampToDayOfYear, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetDayOfYear(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToDayOfYearWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType)), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetDayOfYear))), function(overloads.TimeGetDayOfMonth, decls.MemberOverload(overloads.TimestampToDayOfMonthZeroBased, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetDayOfMonthZeroBased(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType)), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetDayOfMonthZeroBased))), function(overloads.TimeGetDate, decls.MemberOverload(overloads.TimestampToDayOfMonthOneBased, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetDayOfMonthOneBased(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToDayOfMonthOneBasedWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType)), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetDayOfMonthOneBased))), function(overloads.TimeGetDayOfWeek, decls.MemberOverload(overloads.TimestampToDayOfWeek, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), 
types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetDayOfWeek(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToDayOfWeekWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType)), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetDayOfWeek))), function(overloads.TimeGetHours, decls.MemberOverload(overloads.TimestampToHours, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetHours(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToHoursWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetHours)), decls.MemberOverload(overloads.DurationToHours, - argTypes(types.DurationType), types.IntType)), + argTypes(types.DurationType), types.IntType, + decls.UnaryBinding(types.DurationGetHours))), function(overloads.TimeGetMinutes, decls.MemberOverload(overloads.TimestampToMinutes, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetMinutes(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToMinutesWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetMinutes)), decls.MemberOverload(overloads.DurationToMinutes, - argTypes(types.DurationType), types.IntType)), + argTypes(types.DurationType), types.IntType, + decls.UnaryBinding(types.DurationGetMinutes))), function(overloads.TimeGetSeconds, decls.MemberOverload(overloads.TimestampToSeconds, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return 
timestampGetSeconds(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToSecondsWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetSeconds)), decls.MemberOverload(overloads.DurationToSeconds, - argTypes(types.DurationType), types.IntType)), + argTypes(types.DurationType), types.IntType, + decls.UnaryBinding(types.DurationGetSeconds))), function(overloads.TimeGetMilliseconds, decls.MemberOverload(overloads.TimestampToMilliseconds, - argTypes(types.TimestampType), types.IntType), + argTypes(types.TimestampType), types.IntType, + decls.UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetMilliseconds(ts, utcTZ) + })), decls.MemberOverload(overloads.TimestampToMillisecondsWithTz, - argTypes(types.TimestampType, types.StringType), types.IntType), + argTypes(types.TimestampType, types.StringType), types.IntType, + decls.BinaryBinding(timestampGetMilliseconds)), decls.MemberOverload(overloads.DurationToMilliseconds, - argTypes(types.DurationType), types.IntType)), + argTypes(types.DurationType), types.IntType, + decls.UnaryBinding(types.DurationGetMilliseconds))), } } @@ -618,3 +667,118 @@ func convertToType(t ref.Type) functions.UnaryOp { return val.ConvertToType(t) } } + +func timestampGetFullYear(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.Year()) +} + +func timestampGetMonth(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + // CEL spec indicates that the month should be 0-based, but the Time value + // for Month() is 1-based. 
+ return types.Int(t.Month() - 1) +} + +func timestampGetDayOfYear(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.YearDay() - 1) +} + +func timestampGetDayOfMonthZeroBased(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.Day() - 1) +} + +func timestampGetDayOfMonthOneBased(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.Day()) +} + +func timestampGetDayOfWeek(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.Weekday()) +} + +func timestampGetHours(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.Hour()) +} + +func timestampGetMinutes(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.Minute()) +} + +func timestampGetSeconds(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.Second()) +} + +func timestampGetMilliseconds(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErrFromString(err.Error()) + } + return types.Int(t.Nanosecond() / 1000000) +} + +func inTimeZone(ts, tz ref.Val) (time.Time, error) { + t := ts.(types.Timestamp) + val := string(tz.(types.String)) + ind := strings.Index(val, ":") + if ind == -1 { + loc, err := time.LoadLocation(val) + if err != nil { + return time.Time{}, err + } + return t.In(loc), nil + } + + // If the input is not the name of a timezone (for example, 'US/Central'), it should be a numerical offset from UTC + // in the format 
^(+|-)(0[0-9]|1[0-4]):[0-5][0-9]$. The numerical input is parsed in terms of hours and minutes. + hr, err := strconv.Atoi(string(val[0:ind])) + if err != nil { + return time.Time{}, err + } + min, err := strconv.Atoi(string(val[ind+1:])) + if err != nil { + return time.Time{}, err + } + var offset int + if string(val[0]) == "-" { + offset = hr*60 - min + } else { + offset = hr*60 + min + } + secondsEastOfUTC := int((time.Duration(offset) * time.Minute).Seconds()) + timezone := time.FixedZone("", secondsEastOfUTC) + return t.In(timezone), nil +} diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index cf1f56de1..498be1340 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -707,6 +707,46 @@ func testData(t testing.TB) []testCase { name: "timestamp_ge_timestamp", expr: `timestamp(2) >= timestamp(2)`, }, + { + name: "timestamp_methods", + vars: []*decls.VariableDecl{ + decls.NewVariable("x", types.TimestampType), + }, + in: map[string]any{ + "x": time.Unix(7506, 1000000).Local(), + }, + expr: ` + x.getFullYear() == 1970 + && x.getMonth() == 0 + && x.getDayOfYear() == 0 + && x.getDayOfMonth() == 0 + && x.getDate() == 1 + && x.getDayOfWeek() == 4 + && x.getHours() == 2 + && x.getMinutes() == 5 + && x.getSeconds() == 6 + && x.getMilliseconds() == 1 + && x.getFullYear('-07:30') == 1969 + && x.getDayOfYear('-07:30') == 364 + && x.getMonth('-07:30') == 11 + && x.getDayOfMonth('-07:30') == 30 + && x.getDate('-07:30') == 31 + && x.getDayOfWeek('-07:30') == 3 + && x.getHours('-07:30') == 18 + && x.getMinutes('-07:30') == 35 + && x.getSeconds('-07:30') == 6 + && x.getMilliseconds('-07:30') == 1 + && x.getFullYear('23:15') == 1970 + && x.getDayOfYear('23:15') == 1 + && x.getMonth('23:15') == 0 + && x.getDayOfMonth('23:15') == 1 + && x.getDate('23:15') == 2 + && x.getDayOfWeek('23:15') == 5 + && x.getHours('23:15') == 1 + && x.getMinutes('23:15') == 20 + && x.getSeconds('23:15') == 6 + && x.getMilliseconds('23:15') == 
1`, + }, { name: "string_to_timestamp", expr: `timestamp('1986-04-26T01:23:40Z')`, diff --git a/policy/compiler_test.go b/policy/compiler_test.go index 9a4497846..c85f8702e 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -159,7 +159,6 @@ func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel. t.Errorf("policy name is %v, wanted %q", policy.name, name) } env, err := cel.NewEnv( - cel.DefaultUTCTimeZone(true), cel.OptionalTypes(), cel.EnableMacroCallTracking(), cel.ExtendedValidations(), From 4b27149545c7ca116b937af501cd1737e4448a4a Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 13 Feb 2025 18:07:00 -0800 Subject: [PATCH 12/46] Option to configure CEL via env.Config object (#1129) * FromConfig option to configure CEL from an external config file * Validation tests and cel package tests --- cel/cel_test.go | 61 +++++-- cel/env_test.go | 356 +++++++++++++++++++++++++++++++++++++--- cel/library.go | 13 +- cel/options.go | 130 +++++++++++++++ common/env/env.go | 302 ++++++++++++++++++++++++++-------- common/env/env_test.go | 316 +++++++++++++++++++++++++++++------ policy/compiler_test.go | 8 +- policy/config.go | 121 +++----------- policy/config_test.go | 30 ++-- policy/helper_test.go | 4 +- 10 files changed, 1053 insertions(+), 288 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 4f873bbf7..3ff459756 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -361,19 +361,20 @@ func TestExtendStdlibFunction(t *testing.T) { } func TestSubsetStdLib(t *testing.T) { - env, err := NewCustomEnv(StdLib(StdLibSubset( - &env.LibrarySubset{ - IncludeMacros: []string{"has"}, - IncludeFunctions: []*env.Function{ - {Name: operators.Equals}, - {Name: operators.NotEquals}, - {Name: operators.LogicalAnd}, - {Name: operators.LogicalOr}, - {Name: operators.LogicalNot}, - {Name: overloads.Size, Overloads: []*env.Overload{{ID: "list_size"}}}, + env, err := NewCustomEnv( + StdLib(StdLibSubset( + &env.LibrarySubset{ + 
IncludeMacros: []string{"has"}, + IncludeFunctions: []*env.Function{ + {Name: operators.Equals}, + {Name: operators.NotEquals}, + {Name: operators.LogicalAnd}, + {Name: operators.LogicalOr}, + {Name: operators.LogicalNot}, + {Name: overloads.Size, Overloads: []*env.Overload{{ID: "list_size"}}}, + }, }, - }, - ))) + ))) if err != nil { t.Fatalf("StdLib() subsetting failed: %v", err) } @@ -443,6 +444,42 @@ func TestSubsetStdLib(t *testing.T) { } } +func TestSubsetStdLibError(t *testing.T) { + _, err := NewCustomEnv( + StdLib(StdLibSubset( + env.NewLibrarySubset().AddIncludedMacros("has").AddExcludedMacros("exists")), + )) + if err == nil || !strings.Contains(err.Error(), "invalid subset") { + t.Errorf("StdLib() subsetting got %v, wanted error 'invalid subset'", err) + } +} + +func TestSubsetStdLibMerge(t *testing.T) { + _, err := NewCustomEnv( + Function("size", MemberOverload("string_size", []*Type{StringType}, IntType)), + StdLib(StdLibSubset( + env.NewLibrarySubset().AddIncludedFunctions([]*env.Function{ + {Name: overloads.Size, Overloads: []*env.Overload{{ID: "string_size"}}}, + }...), + ))) + if err != nil { + t.Errorf("StdLib() subsetting failed to merge: %v", err) + } +} + +func TestSubsetStdLibMergeError(t *testing.T) { + _, err := NewCustomEnv( + Function("size", MemberOverload("string_size", []*Type{StringType}, UintType)), + StdLib(StdLibSubset( + env.NewLibrarySubset().AddIncludedFunctions([]*env.Function{ + {Name: overloads.Size, Overloads: []*env.Overload{{ID: "string_size"}}}, + }...), + ))) + if err == nil || !strings.Contains(err.Error(), "merge failed") { + t.Errorf("StdLib() subsetting got %v, wanted error 'merge failed'", err) + } +} + func TestCustomTypes(t *testing.T) { reg := types.NewEmptyRegistry() env := testEnv(t, diff --git a/cel/env_test.go b/cel/env_test.go index 64e1f2873..55e16d981 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -15,9 +15,11 @@ package cel import ( + "errors" "fmt" "math" "reflect" + "strings" "sync" "testing" @@ 
-27,6 +29,8 @@ import ( "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "google.golang.org/protobuf/proto" + proto3pb "github.com/google/cel-go/test/proto3pb" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -306,27 +310,27 @@ func TestFunctions(t *testing.T) { func TestEnvToConfig(t *testing.T) { tests := []struct { - name string - opts []EnvOption - wantConfig *env.Config + name string + opts []EnvOption + want *env.Config }{ { - name: "std env", - wantConfig: env.NewConfig("std env"), + name: "std env", + want: env.NewConfig("std env"), }, { name: "std env - container", opts: []EnvOption{ Container("example.container"), }, - wantConfig: env.NewConfig("std env - container").SetContainer("example.container"), + want: env.NewConfig("std env - container").SetContainer("example.container"), }, { name: "std env - aliases", opts: []EnvOption{ Abbrevs("example.type.name"), }, - wantConfig: env.NewConfig("std env - aliases").AddImports(env.NewImport("example.type.name")), + want: env.NewConfig("std env - aliases").AddImports(env.NewImport("example.type.name")), }, { name: "std env disabled", @@ -335,7 +339,7 @@ func TestEnvToConfig(t *testing.T) { return NewCustomEnv() }, }, - wantConfig: env.NewConfig("std env disabled").SetStdLib( + want: env.NewConfig("std env disabled").SetStdLib( env.NewLibrarySubset().SetDisabled(true)), }, { @@ -343,15 +347,15 @@ func TestEnvToConfig(t *testing.T) { opts: []EnvOption{ Variable("var", IntType), }, - wantConfig: env.NewConfig("std env - with variable").AddVariables(env.NewVariable("var", env.NewTypeDesc("int"))), + want: env.NewConfig("std env - with variable").AddVariables(env.NewVariable("var", env.NewTypeDesc("int"))), }, { name: "std env - with function", opts: []EnvOption{Function("hello", Overload("hello_string", []*Type{StringType}, StringType))}, - wantConfig: env.NewConfig("std env - with function").AddFunctions( - env.NewFunction("hello", []*env.Overload{ + want: 
env.NewConfig("std env - with function").AddFunctions( + env.NewFunction("hello", env.NewOverload("hello_string", - []*env.TypeDesc{env.NewTypeDesc("string")}, env.NewTypeDesc("string"))}, + []*env.TypeDesc{env.NewTypeDesc("string")}, env.NewTypeDesc("string")), )), }, { @@ -359,14 +363,14 @@ func TestEnvToConfig(t *testing.T) { opts: []EnvOption{ OptionalTypes(), }, - wantConfig: env.NewConfig("optional lib").AddExtensions(env.NewExtension("optional", math.MaxUint32)), + want: env.NewConfig("optional lib").AddExtensions(env.NewExtension("optional", math.MaxUint32)), }, { name: "optional lib - versioned", opts: []EnvOption{ OptionalTypes(OptionalTypesVersion(1)), }, - wantConfig: env.NewConfig("optional lib - versioned").AddExtensions(env.NewExtension("optional", 1)), + want: env.NewConfig("optional lib - versioned").AddExtensions(env.NewExtension("optional", 1)), }, { name: "optional lib - alt last()", @@ -374,11 +378,11 @@ func TestEnvToConfig(t *testing.T) { OptionalTypes(), Function("last", MemberOverload("string_last", []*Type{StringType}, StringType)), }, - wantConfig: env.NewConfig("optional lib - alt last()"). + want: env.NewConfig("optional lib - alt last()"). AddExtensions(env.NewExtension("optional", math.MaxUint32)). - AddFunctions(env.NewFunction("last", []*env.Overload{ + AddFunctions(env.NewFunction("last", env.NewMemberOverload("string_last", env.NewTypeDesc("string"), []*env.TypeDesc{}, env.NewTypeDesc("string")), - })), + )), }, { name: "context proto - with extra variable", @@ -386,7 +390,7 @@ func TestEnvToConfig(t *testing.T) { DeclareContextProto((&proto3pb.TestAllTypes{}).ProtoReflect().Descriptor()), Variable("extra", StringType), }, - wantConfig: env.NewConfig("context proto - with extra variable"). + want: env.NewConfig("context proto - with extra variable"). SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")). 
AddVariables(env.NewVariable("extra", env.NewTypeDesc("string"))), }, @@ -395,7 +399,7 @@ func TestEnvToConfig(t *testing.T) { opts: []EnvOption{ DeclareContextProto((&proto3pb.TestAllTypes{}).ProtoReflect().Descriptor()), }, - wantConfig: env.NewConfig("context proto").SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")), + want: env.NewConfig("context proto").SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")), }, } @@ -410,8 +414,309 @@ func TestEnvToConfig(t *testing.T) { if err != nil { t.Fatalf("ToConfig() failed: %v", err) } - if !reflect.DeepEqual(gotConfig, tc.wantConfig) { - t.Errorf("e.Config() got %v, wanted %v", gotConfig, tc.wantConfig) + if !reflect.DeepEqual(gotConfig, tc.want) { + t.Errorf("e.Config() got %v, wanted %v", gotConfig, tc.want) + } + }) + } +} + +func TestEnvFromConfig(t *testing.T) { + type exprCase struct { + name string + in any + expr string + iss error + out ref.Val + } + tests := []struct { + name string + beforeOpts []EnvOption + afterOpts []EnvOption + conf *env.Config + exprs []exprCase + }{ + { + name: "std env", + conf: env.NewConfig("std env"), + exprs: []exprCase{ + { + name: "literal", + expr: "'hello world'", + out: types.String("hello world"), + }, + { + name: "size", + expr: "'hello world'.size()", + out: types.Int(11), + }, + }, + }, + { + name: "std env - imports", + beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})}, + conf: env.NewConfig("std env - context proto"). + AddImports(env.NewImport("google.expr.proto3.test.TestAllTypes")), + exprs: []exprCase{ + { + name: "literal", + expr: "TestAllTypes{single_int64: 15}.single_int64", + out: types.Int(15), + }, + }, + }, + { + name: "std env - context proto", + beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})}, + conf: env.NewConfig("std env - context proto"). + SetContainer("google.expr.proto3.test"). 
+ SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")), + exprs: []exprCase{ + { + name: "field select literal", + in: mustContextProto(t, &proto3pb.TestAllTypes{SingleInt64: 10}), + expr: "TestAllTypes{single_int64: single_int64}.single_int64", + out: types.Int(10), + }, + }, + }, + { + name: "custom env - variables", + beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})}, + conf: env.NewConfig("custom env - variables"). + SetStdLib(env.NewLibrarySubset().SetDisabled(true)). + SetContainer("google.expr.proto3.test"). + AddVariables(env.NewVariable("single_int64", env.NewTypeDesc("int"))), + exprs: []exprCase{ + { + name: "field select literal", + in: map[string]any{"single_int64": 42}, + expr: "TestAllTypes{single_int64: single_int64}.single_int64", + out: types.Int(42), + }, + { + name: "invalid operator", + in: map[string]any{"single_int64": 42}, + expr: "TestAllTypes{single_int64: single_int64}.single_int64 + 1", + iss: errors.New("undeclared reference"), + }, + }, + }, + { + name: "custom env - functions", + afterOpts: []EnvOption{ + Function("plus", + MemberOverload("int_plus_int", []*Type{IntType, IntType}, IntType, + BinaryBinding(func(lhs, rhs ref.Val) ref.Val { + l := lhs.(types.Int) + r := rhs.(types.Int) + return l + r + }), + ), + )}, + conf: env.NewConfig("custom env - functions"). + SetStdLib(env.NewLibrarySubset().SetDisabled(true)). + AddVariables(env.NewVariable("x", env.NewTypeDesc("int"))). 
+ AddFunctions(env.NewFunction("plus", + env.NewMemberOverload("int_plus_int", + env.NewTypeDesc("int"), + []*env.TypeDesc{env.NewTypeDesc("int")}, + env.NewTypeDesc("int"), + ), + )), + exprs: []exprCase{ + { + name: "plus", + in: map[string]any{"x": 42}, + expr: "x.plus(2)", + out: types.Int(44), + }, + { + name: "plus invalid type", + in: map[string]any{"x": 42}, + expr: "x.plus(2.0)", + iss: errors.New("no matching overload"), + }, + }, + }, + { + name: "pure custom env", + beforeOpts: []EnvOption{func(*Env) (*Env, error) { + return NewCustomEnv() + }}, + conf: env.NewConfig("pure custom env").SetStdLib( + env.NewLibrarySubset().AddIncludedFunctions([]*env.Function{{Name: "_==_"}}...), + ), + exprs: []exprCase{ + { + name: "equals", + expr: "'hello world' == 'hello'", + out: types.False, + }, + { + name: "not equals - invalid", + expr: "'hello world' != 'hello'", + iss: errors.New("undeclared reference"), + }, + }, + }, + { + name: "std env - allow subset", + conf: env.NewConfig("std env - allow subset").SetStdLib( + env.NewLibrarySubset().AddIncludedFunctions([]*env.Function{{Name: "_==_"}}...), + ), + exprs: []exprCase{ + { + name: "equals", + expr: "'hello world' == 'hello'", + out: types.False, + }, + { + name: "not equals - invalid", + expr: "'hello world' != 'hello'", + iss: errors.New("undeclared reference"), + }, + }, + }, + { + name: "std env - deny subset", + conf: env.NewConfig("std env - deny subset").SetStdLib( + env.NewLibrarySubset().AddExcludedFunctions([]*env.Function{{Name: "size"}}...), + ), + exprs: []exprCase{ + { + name: "size - invalid", + expr: "'hello world'.size()", + iss: errors.New("undeclared reference"), + }, + { + name: "equals", + expr: "'hello world' == 'hello'", + out: types.False, + }, + }, + }, + { + name: "extensions", + conf: env.NewConfig("extensions"). + AddVariables( + env.NewVariable("m", + env.NewTypeDesc("map", env.NewTypeDesc("string"), env.NewTypeDesc("string")))). 
+ AddExtensions(env.NewExtension("optional", math.MaxUint32)), + exprs: []exprCase{ + { + name: "optional none", + expr: "optional.none()", + out: types.OptionalNone, + }, + { + name: "optional key", + expr: "m.?key.hasValue()", + in: map[string]any{"m": map[string]string{"key": "value"}}, + out: types.True, + }, + }, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + opts := tc.beforeOpts + opts = append(opts, FromConfig(tc.conf, func(elem any) (EnvOption, bool) { + if ext, ok := elem.(*env.Extension); ok && ext.Name == "optional" { + ver, _ := ext.GetVersion() + return OptionalTypes(OptionalTypesVersion(ver)), true + } + return nil, false + })) + opts = append(opts, tc.afterOpts...) + var e *Env + var err error + if tc.conf.StdLib != nil { + e, err = NewCustomEnv(opts...) + } else { + e, err = NewEnv(opts...) + } + if err != nil { + t.Fatalf("NewEnv(FromConfig()) failed: %v", err) + } + for _, ex := range tc.exprs { + t.Run(ex.name, func(t *testing.T) { + ast, iss := e.Compile(ex.expr) + if iss.Err() != nil { + if ex.iss == nil || !strings.Contains(iss.Err().Error(), ex.iss.Error()) { + t.Errorf("e.Compile() failed with %v, wanted %v", iss.Err(), ex.iss) + } + return + } + if ex.iss != nil { + t.Fatalf("e.Compile() succeeded, wanted error %v", ex.iss) + } + prg, err := e.Program(ast) + if err != nil { + t.Fatalf("e.Program() failed: %v", err) + } + var in any = map[string]any{} + if ex.in != nil { + in = ex.in + } + out, _, err := prg.Eval(in) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out.Equal(ex.out) != types.True { + t.Errorf("prg.Eval() got %v, wanted %v", out, ex.out) + } + }) + } + }) + } +} + +func TestEnvFromConfigErrors(t *testing.T) { + tests := []struct { + name string + conf *env.Config + want error + }{ + { + name: "invalid subset", + conf: env.NewConfig("invalid subset").SetStdLib(env.NewLibrarySubset().SetDisableMacros(true)), + want: errors.New("invalid subset"), + }, + { + name: 
"invalid import", + conf: env.NewConfig("invalid import").AddImports(env.NewImport("")), + want: errors.New("invalid import"), + }, + { + name: "invalid context proto", + conf: env.NewConfig("invalid context proto").SetContextVariable(env.NewContextVariable("invalid")), + want: errors.New("invalid context proto type"), + }, + { + name: "undefined variable type", + conf: env.NewConfig("undefined variable type").AddVariables(env.NewVariable("undef", env.NewTypeDesc("undefined"))), + want: errors.New("invalid variable"), + }, + { + name: "undefined function type", + conf: env.NewConfig("undefined function type").AddFunctions(env.NewFunction("invalid", env.NewOverload("invalid", []*env.TypeDesc{}, env.NewTypeDesc("undefined")))), + want: errors.New("invalid function"), + }, + { + name: "unrecognized extension", + conf: env.NewConfig("unrecognized extension"). + AddExtensions(env.NewExtension("optional", math.MaxUint32)), + want: errors.New("unrecognized extension"), + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + _, err := NewEnv(FromConfig(tc.conf)) + if err == nil || !strings.Contains(err.Error(), tc.want.Error()) { + t.Fatalf("NewEnv(FromConfig()) got %v, wanted error containing %v", err, tc.want) } }) } @@ -515,6 +820,15 @@ func BenchmarkEnvExtendEagerDecls(b *testing.B) { } } +func mustContextProto(t *testing.T, pb proto.Message) Activation { + t.Helper() + ctx, err := ContextProtoVars(pb) + if err != nil { + t.Fatalf("ContextProtoVars() failed: %v", err) + } + return ctx +} + type customLegacyProvider struct { provider ref.TypeProvider } diff --git a/cel/library.go b/cel/library.go index c2fcafc11..fde16a019 100644 --- a/cel/library.go +++ b/cel/library.go @@ -150,11 +150,6 @@ func (*stdLibrary) LibraryAlias() string { return "stdlib" } -// LibraryVersion returns the version of the library. 
-func (*stdLibrary) LibraryVersion() uint32 { - return math.MaxUint32 -} - // LibrarySubset returns the env.LibrarySubset definition associated with the CEL Library. func (lib *stdLibrary) LibrarySubset() *env.LibrarySubset { return lib.subset @@ -183,6 +178,10 @@ func (lib *stdLibrary) CompileOptions() []EnvOption { return []EnvOption{ func(e *Env) (*Env, error) { var err error + if err = lib.subset.Validate(); err != nil { + return nil, err + } + e.variables = append(e.variables, stdlib.Types()...) for _, fn := range funcs { existing, found := e.functions[fn.Name()] if found { @@ -195,10 +194,6 @@ func (lib *stdLibrary) CompileOptions() []EnvOption { } return e, nil }, - func(e *Env) (*Env, error) { - e.variables = append(e.variables, stdlib.Types()...) - return e, nil - }, Macros(macros...), } } diff --git a/cel/options.go b/cel/options.go index 8b170d5de..06d37049d 100644 --- a/cel/options.go +++ b/cel/options.go @@ -15,6 +15,7 @@ package cel import ( + "errors" "fmt" "google.golang.org/protobuf/proto" @@ -26,6 +27,7 @@ import ( "github.com/google/cel-go/checker" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/pb" @@ -113,6 +115,8 @@ func CustomTypeProvider(provider any) EnvOption { // Note: Declarations will by default be appended to the pre-existing declaration set configured // for the environment. The NewEnv call builds on top of the standard CEL declarations. For a // purely custom set of declarations use NewCustomEnv. +// +// Deprecated: use FunctionDecls and VariableDecls or FromConfig instead. 
func Declarations(decls ...*exprpb.Decl) EnvOption { declOpts := []EnvOption{} var err error @@ -427,6 +431,132 @@ func OptimizeRegex(regexOptimizations ...*interpreter.RegexOptimization) Program } } +// ConfigOptionFactory declares a signature which accepts a configuration element, e.g. env.Extension +// and optionally produces an EnvOption in response. +// +// If there are multiple ConfigOptionFactory values which could apply to the same configuration node +// the first one that returns an EnvOption and a `true` response will be used, and the config node +// will not be passed along to any other option factory. +// +// Only the *env.Extension type is provided at this time, but validators, optimizers, and other tuning +// parameters may be supported in the future. +type ConfigOptionFactory func(any) (EnvOption, bool) + +// FromConfig produces and applies a set of EnvOption values derived from an env.Config object. +// +// For configuration elements which refer to features outside of the `cel` package, an optional set of +// ConfigOptionFactory values may be passed in to support the conversion from static configuration to +// configured cel.Env value. +// +// Note: disabling the standard library will clear the EnvOptions values previously set for the +// environment with the exception of propagating types and adapters over to the new environment. +// +// Note: to support custom types referenced in the configuration file, you must ensure that one of +// the following options appears before the FromConfig option: Types, TypeDescs, or CustomTypeProvider +// as the type provider configured at the time when the config is processed is the one used to derive +// type references from the configuration. 
+func FromConfig(config *env.Config, optFactories ...ConfigOptionFactory) EnvOption { + return func(env *Env) (*Env, error) { + if err := config.Validate(); err != nil { + return nil, err + } + opts, err := configToEnvOptions(config, env.CELTypeProvider(), optFactories) + if err != nil { + return nil, err + } + for _, o := range opts { + env, err = o(env) + if err != nil { + return nil, err + } + } + return env, nil + } +} + +// configToEnvOptions generates a set of EnvOption values (or error) based on a config, a type provider, +// and an optional set of environment options. +func configToEnvOptions(config *env.Config, provider types.Provider, optFactories []ConfigOptionFactory) ([]EnvOption, error) { + // note: ported from cel-go/policy/config.go + envOpts := []EnvOption{} + // Configure the standard lib subset. + if config.StdLib != nil { + envOpts = append(envOpts, func(e *Env) (*Env, error) { + if e.HasLibrary("cel.lib.std") { + return nil, errors.New("invalid subset of stdlib: create a custom env") + } + return e, nil + }) + if !config.StdLib.Disabled { + envOpts = append(envOpts, StdLib(StdLibSubset(config.StdLib))) + } + } else { + envOpts = append(envOpts, StdLib()) + } + + // Configure the container + if config.Container != "" { + envOpts = append(envOpts, Container(config.Container)) + } + + // Configure abbreviations + for _, imp := range config.Imports { + envOpts = append(envOpts, Abbrevs(imp.Name)) + } + + // Configure the context variable declaration + if config.ContextVariable != nil { + typeName := config.ContextVariable.TypeName + if _, found := provider.FindStructType(typeName); !found { + return nil, fmt.Errorf("invalid context proto type: %q", typeName) + } + // Attempt to instantiate the proto in order to reflect to its descriptor + msg := provider.NewValue(typeName, map[string]ref.Val{}) + pbMsg, ok := msg.Value().(proto.Message) + if !ok { + return nil, fmt.Errorf("unsupported context type: %T", msg.Value()) + } + envOpts = append(envOpts, 
DeclareContextProto(pbMsg.ProtoReflect().Descriptor())) + } + + if len(config.Variables) != 0 { + vars := make([]*decls.VariableDecl, 0, len(config.Variables)) + for _, v := range config.Variables { + vDef, err := v.AsCELVariable(provider) + if err != nil { + return nil, err + } + vars = append(vars, vDef) + } + envOpts = append(envOpts, VariableDecls(vars...)) + } + if len(config.Functions) != 0 { + funcs := make([]*decls.FunctionDecl, 0, len(config.Functions)) + for _, f := range config.Functions { + fnDef, err := f.AsCELFunction(provider) + if err != nil { + return nil, err + } + funcs = append(funcs, fnDef) + } + envOpts = append(envOpts, FunctionDecls(funcs...)) + } + for _, e := range config.Extensions { + extHandled := false + for _, optFac := range optFactories { + if opt, useOption := optFac(e); useOption { + envOpts = append(envOpts, opt) + extHandled = true + break + } + } + if !extHandled { + return nil, fmt.Errorf("unrecognized extension: %s", e.Name) + } + } + return envOpts, nil +} + // EvalOption indicates an evaluation option that may affect the evaluation behavior or information // in the output result. type EvalOption int diff --git a/common/env/env.go b/common/env/env.go index 27e28cfd7..10c7b1e72 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -37,7 +37,7 @@ func NewConfig(name string) *Config { // // Note: custom validations, feature flags, and performance tuning parameters are not (yet) // considered part of the core CEL environment configuration and should be managed separately -// until a common convention for such configuration settings is developed. +// until a common convention for such settings is developed. type Config struct { Name string `yaml:"name,omitempty"` Description string `yaml:"description,omitempty"` @@ -50,6 +50,44 @@ type Config struct { Functions []*Function `yaml:"functions,omitempty"` } +// Validate validates the whole configuration is well-formed. 
+func (c *Config) Validate() error { + if c == nil { + return nil + } + var errs []error + for _, imp := range c.Imports { + if err := imp.Validate(); err != nil { + errs = append(errs, err) + } + } + if err := c.StdLib.Validate(); err != nil { + errs = append(errs, err) + } + for _, ext := range c.Extensions { + if err := ext.Validate(); err != nil { + errs = append(errs, err) + } + } + if err := c.ContextVariable.Validate(); err != nil { + errs = append(errs, err) + } + if c.ContextVariable != nil && len(c.Variables) != 0 { + errs = append(errs, errors.New("invalid config: either context variable or variables may be set, but not both")) + } + for _, v := range c.Variables { + if err := v.Validate(); err != nil { + errs = append(errs, err) + } + } + for _, fn := range c.Functions { + if err := fn.Validate(); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + // SetContainer configures the container name for this configuration. func (c *Config) SetContainer(container string) *Config { c.Container = container @@ -105,7 +143,7 @@ func (c *Config) AddFunctionDecls(funcs ...*decls.FunctionDecl) *Config { overloads = append(overloads, NewOverload(overloadID, args, ret)) } } - convFuncs[i] = NewFunction(fn.Name(), overloads) + convFuncs[i] = NewFunction(fn.Name(), overloads...) } return c.AddFunctions(convFuncs...) } @@ -145,6 +183,17 @@ type Import struct { Name string `yaml:"name"` } +// Validate validates the import configuration is well-formed. 
+func (imp *Import) Validate() error { + if imp == nil { + return errors.New("invalid import: nil") + } + if imp.Name == "" { + return errors.New("invalid import: missing type name") + } + return nil +} + // NewVariable returns a serializable variable from a name and type definition func NewVariable(name string, t *TypeDesc) *Variable { return &Variable{Name: name, TypeDesc: t} @@ -165,39 +214,47 @@ type Variable struct { *TypeDesc `yaml:",inline"` } +// Validate validates the variable configuration is well-formed. +func (v *Variable) Validate() error { + if v == nil { + return errors.New("invalid variable: nil") + } + if v.Name == "" { + return errors.New("invalid variable: missing variable name") + } + if err := v.GetType().Validate(); err != nil { + return fmt.Errorf("invalid variable %q: %w", v.Name, err) + } + return nil +} + // GetType returns the variable type description. // // Note, if both the embedded TypeDesc and the field Type are non-nil, the embedded TypeDesc will // take precedence. -func (vd *Variable) GetType() *TypeDesc { - if vd == nil { +func (v *Variable) GetType() *TypeDesc { + if v == nil { return nil } - if vd.TypeDesc != nil { - return vd.TypeDesc + if v.TypeDesc != nil { + return v.TypeDesc } - if vd.Type != nil { - return vd.Type + if v.Type != nil { + return v.Type } return nil } // AsCELVariable converts the serializable form of the Variable into a CEL environment declaration. 
-func (vd *Variable) AsCELVariable(tp types.Provider) (*decls.VariableDecl, error) { - if vd == nil { - return nil, errors.New("nil Variable cannot be converted to a VariableDecl") - } - if vd.Name == "" { - return nil, errors.New("invalid variable, must declare a name") +func (v *Variable) AsCELVariable(tp types.Provider) (*decls.VariableDecl, error) { + if err := v.Validate(); err != nil { + return nil, err } - if vd.GetType() != nil { - t, err := vd.GetType().AsCELType(tp) - if err != nil { - return nil, fmt.Errorf("invalid variable type for '%s': %w", vd.Name, err) - } - return decls.NewVariable(vd.Name, t), nil + t, err := v.GetType().AsCELType(tp) + if err != nil { + return nil, fmt.Errorf("invalid variable %q: %w", v.Name, err) } - return nil, fmt.Errorf("invalid variable '%s', no type specified", vd.Name) + return decls.NewVariable(v.Name, t), nil } // NewContextVariable returns a serializable context variable with a specific type name. @@ -213,8 +270,19 @@ type ContextVariable struct { TypeName string `yaml:"type_name"` } +// Validate validates the context-variable configuration is well-formed. +func (ctx *ContextVariable) Validate() error { + if ctx == nil { + return nil + } + if ctx.TypeName == "" { + return errors.New("invalid context variable: missing type name") + } + return nil +} + // NewFunction creates a serializable function and overload set. -func NewFunction(name string, overloads []*Overload) *Function { +func NewFunction(name string, overloads ...*Overload) *Function { return &Function{Name: name, Overloads: overloads} } @@ -225,23 +293,37 @@ type Function struct { Overloads []*Overload `yaml:"overloads,omitempty"` } -// AsCELFunction converts the serializable form of the Function into CEL environment declaration. -func (fn *Function) AsCELFunction(tp types.Provider) (*decls.FunctionDecl, error) { +// Validate validates the function configuration is well-formed. 
+func (fn *Function) Validate() error { if fn == nil { - return nil, errors.New("nil Function cannot be converted to a FunctionDecl") + return errors.New("invalid function: nil") } if fn.Name == "" { - return nil, errors.New("invalid function, must declare a name") + return errors.New("invalid function: missing function name") } if len(fn.Overloads) == 0 { - return nil, fmt.Errorf("invalid function %s, must declare an overload", fn.Name) + return fmt.Errorf("invalid function %q: missing overloads", fn.Name) + } + var errs []error + for _, o := range fn.Overloads { + if err := o.Validate(); err != nil { + errs = append(errs, fmt.Errorf("invalid function %q: %w", fn.Name, err)) + } + } + return errors.Join(errs...) +} + +// AsCELFunction converts the serializable form of the Function into CEL environment declaration. +func (fn *Function) AsCELFunction(tp types.Provider) (*decls.FunctionDecl, error) { + if err := fn.Validate(); err != nil { + return nil, err } - overloads := make([]decls.FunctionOpt, len(fn.Overloads)) var err error + overloads := make([]decls.FunctionOpt, len(fn.Overloads)) for i, o := range fn.Overloads { overloads[i], err = o.AsFunctionOption(tp) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid function %q: %w", fn.Name, err) } } return decls.NewFunction(fn.Name, overloads...) @@ -266,34 +348,60 @@ type Overload struct { Return *TypeDesc `yaml:"return,omitempty"` } +// Validate validates the overload configuration is well-formed. 
+func (od *Overload) Validate() error { + if od == nil { + return errors.New("invalid overload: nil") + } + if od.ID == "" { + return errors.New("invalid overload: missing overload id") + } + var errs []error + if od.Target != nil { + if err := od.Target.Validate(); err != nil { + errs = append(errs, fmt.Errorf("invalid overload %q target: %w", od.ID, err)) + } + } + for i, arg := range od.Args { + if err := arg.Validate(); err != nil { + errs = append(errs, fmt.Errorf("invalid overload %q arg[%d]: %w", od.ID, i, err)) + } + } + if err := od.Return.Validate(); err != nil { + errs = append(errs, fmt.Errorf("invalid overload %q return: %w", od.ID, err)) + } + return errors.Join(errs...) +} + // AsFunctionOption converts the serializable form of the Overload into a function declaration option. func (od *Overload) AsFunctionOption(tp types.Provider) (decls.FunctionOpt, error) { - if od == nil { - return nil, errors.New("nil Overload cannot be converted to a FunctionOpt") + if err := od.Validate(); err != nil { + return nil, err } args := make([]*types.Type, len(od.Args)) var err error + var errs []error for i, a := range od.Args { args[i], err = a.AsCELType(tp) if err != nil { - return nil, err + errs = append(errs, err) } } - if od.Return == nil { - return nil, fmt.Errorf("missing return type on overload: %v", od.ID) - } result, err := od.Return.AsCELType(tp) if err != nil { - return nil, err + errs = append(errs, err) } if od.Target != nil { t, err := od.Target.AsCELType(tp) if err != nil { - return nil, err + return nil, errors.Join(append(errs, err)...) } args = append([]*types.Type{t}, args...) return decls.MemberOverload(od.ID, args, result), nil } + if len(errs) != 0 { + return nil, errors.Join(errs...) + } return decls.Overload(od.ID, args, result), nil } @@ -318,20 +426,29 @@ type Extension struct { Version string `yaml:"version,omitempty"` } +// Validate validates the extension configuration is well-formed. 
+func (e *Extension) Validate() error { + _, err := e.GetVersion() + return err +} + // GetVersion returns the parsed version string, or an error if the version cannot be parsed. func (e *Extension) GetVersion() (uint32, error) { if e == nil { - return 0, errors.New("nil Extension cannot produce a version") + return 0, fmt.Errorf("invalid extension: nil") + } + if e.Name == "" { + return 0, fmt.Errorf("invalid extension: missing name") } if e.Version == "latest" { return math.MaxUint32, nil } if e.Version == "" { - return uint32(0), nil + return 0, nil } ver, err := strconv.ParseUint(e.Version, 10, 32) if err != nil { - return 0, fmt.Errorf("error parsing uint version: %w", err) + return 0, fmt.Errorf("invalid extension %q version: %w", e.Name, err) } return uint32(ver), nil } @@ -369,6 +486,24 @@ type LibrarySubset struct { ExcludeFunctions []*Function `yaml:"exclude_functions,omitempty"` } +// Validate validates the library configuration is well-formed. +// +// For example, setting both the IncludeMacros and ExcludeMacros together could be confusing +// and create a broken expectation, likewise for IncludeFunctions and ExcludeFunctions. +func (lib *LibrarySubset) Validate() error { + if lib == nil { + return nil + } + var errs []error + if len(lib.IncludeMacros) != 0 && len(lib.ExcludeMacros) != 0 { + errs = append(errs, errors.New("invalid subset: cannot both include and exclude macros")) + } + if len(lib.IncludeFunctions) != 0 && len(lib.ExcludeFunctions) != 0 { + errs = append(errs, errors.New("invalid subset: cannot both include and exclude functions")) + } + return errors.Join(errs...) +} + // SubsetFunction produces a function declaration which matches the supported subset, or nil // if the function is not supported by the LibrarySubset. // @@ -520,49 +655,74 @@ func (td *TypeDesc) String() string { return typeName } -// AsCELType converts the serializable object to a *types.Type value. 
-func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) { +// Validate validates the type configuration is well-formed. +func (td *TypeDesc) Validate() error { if td == nil { - return nil, errors.New("nil TypeDesc cannot be converted to a Type instance") + return errors.New("invalid type: nil") } if td.TypeName == "" { - return nil, errors.New("invalid type description, declare a type name") + return errors.New("invalid type: missing type name") + } + if td.IsTypeParam && len(td.Params) != 0 { + return errors.New("invalid type: param type cannot have parameters") + } + switch td.TypeName { + case "list": + if len(td.Params) != 1 { + return fmt.Errorf("invalid type: list expects 1 parameter, got %d", len(td.Params)) + } + return td.Params[0].Validate() + case "map": + if len(td.Params) != 2 { + return fmt.Errorf("invalid type: map expects 2 parameters, got %d", len(td.Params)) + } + if err := td.Params[0].Validate(); err != nil { + return err + } + if err := td.Params[1].Validate(); err != nil { + return err + } + case "optional_type": + if len(td.Params) != 1 { + return fmt.Errorf("invalid type: optional_type expects 1 parameter, got %d", len(td.Params)) + } + return td.Params[0].Validate() + default: + } + return nil +} + +// AsCELType converts the serializable object to a *types.Type value. 
+func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) { + err := td.Validate() + if err != nil { + return nil, err } - var err error switch td.TypeName { case "dyn": return types.DynType, nil case "map": - if len(td.Params) == 2 { - kt, err := td.Params[0].AsCELType(tp) - if err != nil { - return nil, err - } - vt, err := td.Params[1].AsCELType(tp) - if err != nil { - return nil, err - } - return types.NewMapType(kt, vt), nil + kt, err := td.Params[0].AsCELType(tp) + if err != nil { + return nil, err + } + vt, err := td.Params[1].AsCELType(tp) + if err != nil { + return nil, err } - return nil, fmt.Errorf("map type has unexpected param count: %d", len(td.Params)) + return types.NewMapType(kt, vt), nil case "list": - if len(td.Params) == 1 { - et, err := td.Params[0].AsCELType(tp) - if err != nil { - return nil, err - } - return types.NewListType(et), nil + et, err := td.Params[0].AsCELType(tp) + if err != nil { + return nil, err } - return nil, fmt.Errorf("list type has unexpected param count: %d", len(td.Params)) + return types.NewListType(et), nil case "optional_type": - if len(td.Params) == 1 { - et, err := td.Params[0].AsCELType(tp) - if err != nil { - return nil, err - } - return types.NewOptionalType(et), nil + et, err := td.Params[0].AsCELType(tp) + if err != nil { + return nil, err } - return nil, fmt.Errorf("optional_type has unexpected param count: %d", len(td.Params)) + return types.NewOptionalType(et), nil default: if td.IsTypeParam { return types.NewTypeParamType(td.TypeName), nil @@ -573,7 +733,7 @@ func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) { } t, found := tp.FindIdent(td.TypeName) if !found { - return nil, fmt.Errorf("undefined type name: %v", td.TypeName) + return nil, fmt.Errorf("undefined type name: %q", td.TypeName) } _, ok := t.(*types.Type) if ok && len(td.Params) == 0 { diff --git a/common/env/env_test.go b/common/env/env_test.go index 157d5d35b..1a18038a8 100644 --- a/common/env/env_test.go +++ 
b/common/env/env_test.go @@ -55,7 +55,7 @@ func TestConfig(t *testing.T) { AddExtensions(NewExtension("optional", math.MaxUint32), NewExtension("strings", 1)). SetContextVariable(NewContextVariable("google.expr.proto3.test.TestAllTypes")). AddFunctions( - NewFunction("coalesce", []*Overload{ + NewFunction("coalesce", NewOverload("coalesce_wrapped_int", []*TypeDesc{NewTypeDesc("google.protobuf.Int64Value"), NewTypeDesc("int")}, NewTypeDesc("int")), @@ -65,7 +65,7 @@ func TestConfig(t *testing.T) { NewOverload("coalesce_wrapped_uint", []*TypeDesc{NewTypeDesc("google.protobuf.UInt64Value"), NewTypeDesc("uint")}, NewTypeDesc("uint")), - }), + ), ), }, { @@ -78,14 +78,14 @@ func TestConfig(t *testing.T) { ).AddVariables( NewVariable("msg", NewTypeDesc("google.expr.proto3.test.TestAllTypes")), ).AddFunctions( - NewFunction("isEmpty", []*Overload{ + NewFunction("isEmpty", NewMemberOverload("wrapper_string_isEmpty", NewTypeDesc("google.protobuf.StringValue"), nil, NewTypeDesc("bool")), NewMemberOverload("list_isEmpty", NewTypeDesc("list", NewTypeParam("T")), nil, NewTypeDesc("bool")), - }), + ), ), }, { @@ -124,6 +124,9 @@ func TestConfig(t *testing.T) { t.Fatalf("os.ReadFile(%q) failed: %v", fileName, err) } got := unmarshalYAML(t, data) + if err := got.Validate(); err != nil { + t.Errorf("Validate() got %v, wanted nil error", err) + } if got.Container != tc.want.Container { t.Errorf("Container got %s, wanted %s", got.Container, tc.want.Container) } @@ -169,17 +172,64 @@ func TestConfig(t *testing.T) { } } -func TestNewImport(t *testing.T) { - imp := NewImport("qualified.type.name") - if imp.Name != "qualified.type.name" { - t.Errorf("NewImport() got name: %s, wanted %s", imp.Name, "qualified.type.name") +func TestConfigValidateErrors(t *testing.T) { + tests := []struct { + name string + in *Config + want error + }{ + { + name: "nil config valid", + }, + { + name: "invalid import", + in: NewConfig("invalid import").AddImports(NewImport("")), + want: errors.New("invalid 
import"), + }, + { + name: "invalid subset", + in: NewConfig("invalid subset").SetStdLib(NewLibrarySubset().AddExcludedMacros("has").AddIncludedMacros("exists")), + want: errors.New("invalid subset"), + }, + { + name: "invalid extension", + in: NewConfig("invalid extension").AddExtensions(NewExtension("", 0)), + want: errors.New("invalid extension"), + }, + { + name: "invalid context variable", + in: NewConfig("invalid context variable").SetContextVariable(NewContextVariable("")), + want: errors.New("invalid context variable"), + }, + { + name: "invalid variable", + in: NewConfig("invalid variable").AddVariables(NewVariable("", nil)), + want: errors.New("invalid variable"), + }, + { + name: "invalid function", + in: NewConfig("invalid function").AddFunctions(NewFunction("", nil)), + want: errors.New("invalid function"), + }, } -} -func TestNewContextVariable(t *testing.T) { - ctx := NewContextVariable("qualified.type.name") - if ctx.TypeName != "qualified.type.name" { - t.Errorf("NewContextVariable() got name: %s, wanted %s", ctx.TypeName, "qualified.type.name") + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if err == nil && tc.want == nil { + return + } + if err == nil && tc.want != nil { + t.Fatalf("config.Validate() got valid, wanted error %v", tc.want) + } + if err != nil && tc.want == nil { + t.Fatalf("config.Validate() got error %v, wanted nil error", err) + } + if !strings.Contains(err.Error(), tc.want.Error()) { + t.Errorf("config.Validate() got error %v, wanted %v", err, tc.want) + } + }) } } @@ -247,18 +297,18 @@ func TestConfigAddFunctionDecls(t *testing.T) { in: mustNewFunction(t, "size", decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType), ), - out: NewFunction("size", []*Overload{ + out: NewFunction("size", NewOverload("size_string", []*TypeDesc{NewTypeDesc("string")}, NewTypeDesc("int")), - }), + ), }, { name: "global function decl - nullable arg", in: 
mustNewFunction(t, "size", decls.Overload("size_wrapper_string", []*types.Type{types.NewNullableType(types.StringType)}, types.IntType), ), - out: NewFunction("size", []*Overload{ + out: NewFunction("size", NewOverload("size_wrapper_string", []*TypeDesc{NewTypeDesc("google.protobuf.StringValue")}, NewTypeDesc("int")), - }), + ), }, { name: "member function decl - nullable arg", @@ -266,10 +316,10 @@ func TestConfigAddFunctionDecls(t *testing.T) { decls.MemberOverload("list_size", []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType), decls.MemberOverload("string_size", []*types.Type{types.StringType}, types.IntType), ), - out: NewFunction("size", []*Overload{ + out: NewFunction("size", NewMemberOverload("list_size", NewTypeDesc("list", NewTypeParam("T")), []*TypeDesc{}, NewTypeDesc("int")), NewMemberOverload("string_size", NewTypeDesc("string"), []*TypeDesc{}, NewTypeDesc("int")), - }), + ), }, } for _, tst := range tests { @@ -286,6 +336,42 @@ func TestConfigAddFunctionDecls(t *testing.T) { } } +func TestNewImport(t *testing.T) { + imp := NewImport("qualified.type.name") + if imp.Name != "qualified.type.name" { + t.Errorf("NewImport() got name: %s, wanted %s", imp.Name, "qualified.type.name") + } +} + +func TestImportValidate(t *testing.T) { + var imp *Import + err := imp.Validate() + if err == nil || !strings.Contains(err.Error(), "invalid import") { + t.Errorf("imp.Validate() got %v, wanted error 'invalid import'", err) + } + + imp = NewImport("") + err = imp.Validate() + if err == nil || !strings.Contains(err.Error(), "invalid import") { + t.Errorf("imp.Validate() got %v, wanted error 'invalid import'", err) + } +} + +func TestNewContextVariable(t *testing.T) { + ctx := NewContextVariable("qualified.type.name") + if ctx.TypeName != "qualified.type.name" { + t.Errorf("NewContextVariable() got name: %s, wanted %s", ctx.TypeName, "qualified.type.name") + } +} + +func TestContextVariableValidate(t *testing.T) { + ctx := 
NewContextVariable("") + err := ctx.Validate() + if err == nil || !strings.Contains(err.Error(), "invalid context variable") { + t.Errorf("ctx.Validate() got %v, wanted error 'invalid context variable'", err) + } +} + func TestVariableGetType(t *testing.T) { tests := []struct { name string @@ -341,7 +427,7 @@ func TestVariableAsCELVariable(t *testing.T) { { name: "nil-safety check", v: nil, - want: errors.New("nil Variable"), + want: errors.New("invalid variable: nil"), }, { name: "no variable name", @@ -353,7 +439,7 @@ func TestVariableAsCELVariable(t *testing.T) { v: &Variable{ Name: "hello", }, - want: errors.New("no type specified"), + want: errors.New("invalid type: nil"), }, { name: "bad type", @@ -361,7 +447,15 @@ func TestVariableAsCELVariable(t *testing.T) { Name: "hello", TypeDesc: &TypeDesc{}, }, - want: errors.New("declare a type name"), + want: errors.New("missing type name"), + }, + { + name: "undefined type", + v: &Variable{ + Name: "hello", + TypeDesc: &TypeDesc{TypeName: "undefined"}, + }, + want: errors.New("undefined type name"), }, { name: "int type", @@ -449,7 +543,7 @@ func TestVariableAsCELVariable(t *testing.T) { t.Fatalf("AsCELVariable() got error %v, wanted %v", err, tc.want) } if !strings.Contains(err.Error(), wantErr.Error()) { - t.Fatalf("AsCELVariable() got error %v, wanted error contining %v", err, wantErr) + t.Fatalf("AsCELVariable() got error %v, wanted error containing %v", err, wantErr) } return } @@ -485,44 +579,70 @@ func TestFunctionAsCELFunction(t *testing.T) { { name: "nil function", f: nil, - want: errors.New("nil Function"), + want: errors.New("invalid function: nil"), }, { name: "unnamed function", f: &Function{}, - want: errors.New("must declare a name"), + want: errors.New("invalid function"), }, { name: "no overloads", - f: NewFunction("no_overloads", []*Overload{}), - want: errors.New("must declare an overload"), + f: NewFunction("no_overloads"), + want: errors.New("missing overloads"), }, { name: "nil overload", - f: 
NewFunction("no_overloads", []*Overload{nil}), - want: errors.New("nil Overload"), + f: NewFunction("no_overloads", nil), + want: errors.New("invalid overload: nil"), + }, + { + name: "missing overload id", + f: NewFunction("size", &Overload{}), + want: errors.New("missing overload id"), }, { name: "no return type", - f: NewFunction("size", []*Overload{ + f: NewFunction("size", NewOverload("size_string", []*TypeDesc{NewTypeDesc("string")}, nil), - }), - want: errors.New("missing return type"), + ), + want: errors.New("return: invalid type"), }, { name: "bad return type", - f: NewFunction("size", []*Overload{ + f: NewFunction("size", NewOverload("size_string", []*TypeDesc{NewTypeDesc("string")}, NewTypeDesc("")), - }), + ), want: errors.New("invalid type"), }, { name: "bad arg type", - f: NewFunction("size", []*Overload{ + f: NewFunction("size", NewOverload("size_string", []*TypeDesc{NewTypeDesc("")}, NewTypeDesc("")), - }), + ), want: errors.New("invalid type"), }, + { + name: "undefined arg type", + f: NewFunction("size", + NewOverload("size_undefined", []*TypeDesc{NewTypeDesc("undefined")}, NewTypeDesc("int")), + ), + want: errors.New("undefined type"), + }, + { + name: "undefined return type", + f: NewFunction("size", + NewOverload("size_undefined", []*TypeDesc{NewTypeDesc("string")}, NewTypeDesc("undefined")), + ), + want: errors.New("undefined type"), + }, + { + name: "undefined target type", + f: NewFunction("size", + NewMemberOverload("size_undefined", NewTypeDesc("undefined"), []*TypeDesc{NewTypeDesc("string")}, NewTypeDesc("int")), + ), + want: errors.New("undefined type"), + }, { name: "bad target type", f: &Function{Name: "size", @@ -574,7 +694,7 @@ func TestFunctionAsCELFunction(t *testing.T) { t.Fatalf("AsCELFunction() got error %v, wanted %v", err, tc.want) } if !strings.Contains(err.Error(), wantErr.Error()) { - t.Fatalf("AsCELFunction() got error %v, wanted error contining %v", err, wantErr) + t.Fatalf("AsCELFunction() got error %v, wanted error 
containing %v", err, wantErr) } return } @@ -592,37 +712,52 @@ func TestTypeDescAsCELTypeErrors(t *testing.T) { { name: "nil-safety check", t: nil, - want: errors.New("nil TypeDesc"), + want: errors.New("invalid type: nil"), }, { name: "no type name", t: &TypeDesc{}, - want: errors.New("invalid type"), + want: errors.New("missing type name"), }, { name: "invalid optional_type", t: &TypeDesc{TypeName: "optional_type"}, - want: errors.New("unexpected param count"), + want: errors.New("expects 1 parameter"), }, { name: "invalid optional param type", t: &TypeDesc{TypeName: "optional_type", Params: []*TypeDesc{{}}}, want: errors.New("invalid type"), }, + { + name: "undefined optional param type", + t: &TypeDesc{TypeName: "optional_type", Params: []*TypeDesc{{TypeName: "undefined"}}}, + want: errors.New("undefined type"), + }, + { + name: "invalid param type", + t: &TypeDesc{TypeName: "T", IsTypeParam: true, Params: []*TypeDesc{{TypeName: "string"}}}, + want: errors.New("invalid type: param type"), + }, { name: "invalid list", t: &TypeDesc{TypeName: "list"}, - want: errors.New("unexpected param count"), + want: errors.New("expects 1 parameter"), }, { name: "invalid list param type", t: &TypeDesc{TypeName: "list", Params: []*TypeDesc{{}}}, want: errors.New("invalid type"), }, + { + name: "undefined list param type", + t: &TypeDesc{TypeName: "list", Params: []*TypeDesc{{TypeName: "undefined"}}}, + want: errors.New("undefined type name"), + }, { name: "invalid map", t: &TypeDesc{TypeName: "map"}, - want: errors.New("unexpected param count"), + want: errors.New("expects 2 parameters"), }, { name: "invalid map key type", @@ -634,6 +769,16 @@ func TestTypeDescAsCELTypeErrors(t *testing.T) { t: &TypeDesc{TypeName: "map", Params: []*TypeDesc{{TypeName: "string"}, {}}}, want: errors.New("invalid type"), }, + { + name: "undefined map key type", + t: &TypeDesc{TypeName: "map", Params: []*TypeDesc{{TypeName: "undefined"}, {TypeName: "undefined"}}}, + want: errors.New("undefined type 
name"), + }, + { + name: "undefined map value type", + t: &TypeDesc{TypeName: "map", Params: []*TypeDesc{{TypeName: "string"}, {TypeName: "undefined"}}}, + want: errors.New("undefined type name"), + }, { name: "invalid set", t: &TypeDesc{TypeName: "set", Params: []*TypeDesc{{}}}, @@ -641,7 +786,7 @@ func TestTypeDescAsCELTypeErrors(t *testing.T) { }, { name: "undefined type identifier", - t: &TypeDesc{TypeName: "vector"}, + t: &TypeDesc{TypeName: "undefined"}, want: errors.New("undefined type"), }, } @@ -660,7 +805,7 @@ func TestTypeDescAsCELTypeErrors(t *testing.T) { t.Fatalf("AsCELType() got error %v, wanted %v", err, tc.want) } if !strings.Contains(err.Error(), wantErr.Error()) { - t.Fatalf("AsCELType() got error %v, wanted error contining %v", err, wantErr) + t.Fatalf("AsCELType() got error %v, wanted error containing %v", err, wantErr) } return } @@ -671,6 +816,72 @@ func TestTypeDescAsCELTypeErrors(t *testing.T) { } } +func TestLibrarySubsetValidate(t *testing.T) { + tests := []struct { + name string + lib *LibrarySubset + want error + }{ + { + name: "nil library", + lib: NewLibrarySubset(), + }, + { + name: "empty library", + lib: NewLibrarySubset(), + }, + { + name: "only excluded funcs", + lib: NewLibrarySubset().AddExcludedFunctions(NewFunction("size", nil)), + }, + { + name: "only included funcs", + lib: NewLibrarySubset().AddIncludedFunctions(NewFunction("size", nil)), + }, + { + name: "only excluded macros", + lib: NewLibrarySubset().AddExcludedMacros("has"), + }, + { + name: "only included macros", + lib: NewLibrarySubset().AddIncludedMacros("exists"), + }, + { + name: "both included and excluded funcs", + lib: NewLibrarySubset(). + AddIncludedFunctions(NewFunction("size", nil)). + AddExcludedFunctions(NewFunction("size", nil)), + want: errors.New("invalid subset"), + }, + { + name: "both included and excluded macros", + lib: NewLibrarySubset(). + AddIncludedMacros("has"). 
+ AddExcludedMacros("exists"), + want: errors.New("invalid subset"), + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + err := tc.lib.Validate() + if err == nil && tc.want == nil { + return + } + if err == nil && tc.want != nil { + t.Fatalf("lib.Validate() got valid, wanted error %v", tc.want) + } + if err != nil && tc.want == nil { + t.Fatalf("lib.Validate() got error %v, wanted nil error", err) + } + if !strings.Contains(err.Error(), tc.want.Error()) { + t.Errorf("lib.Validate() got error %v, wanted %v", err, tc.want) + } + }) + } +} + func TestSubsetFunction(t *testing.T) { tests := []struct { name string @@ -874,26 +1085,31 @@ func TestExtensionGetVersion(t *testing.T) { }{ { name: "nil extension", - want: errors.New("nil Extension"), + want: errors.New("invalid extension: nil"), }, { - name: "unset version", + name: "missing name", ext: &Extension{}, + want: errors.New("missing name"), + }, + { + name: "unset version", + ext: &Extension{Name: "test"}, want: uint32(0), }, { name: "numeric version", - ext: &Extension{Version: "1"}, + ext: &Extension{Name: "test", Version: "1"}, want: uint32(1), }, { name: "latest version", - ext: &Extension{Version: "latest"}, + ext: &Extension{Name: "test", Version: "latest"}, want: uint32(math.MaxUint32), }, { name: "bad version", - ext: &Extension{Version: "1.0"}, + ext: &Extension{Name: "test", Version: "1.0"}, want: errors.New("invalid syntax"), }, } @@ -907,7 +1123,7 @@ func TestExtensionGetVersion(t *testing.T) { t.Fatalf("GetVersion() got error %v, wanted %v", err, tc.want) } if !strings.Contains(err.Error(), wantErr.Error()) { - t.Fatalf("GetVersion() got error %v, wanted error contining %v", err, wantErr) + t.Fatalf("GetVersion() got error %v, wanted error containing %v", err, wantErr) } return } diff --git a/policy/compiler_test.go b/policy/compiler_test.go index c85f8702e..cc3e80f49 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -158,7 +158,7 @@ func 
compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel. if policy.name.Value != name { t.Errorf("policy name is %v, wanted %q", policy.name, name) } - env, err := cel.NewEnv( + env, err := cel.NewCustomEnv( cel.OptionalTypes(), cel.EnableMacroCallTracking(), cel.ExtendedValidations(), @@ -172,11 +172,7 @@ func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel. t.Fatalf("env.Extend() with env options %v, failed: %v", config, err) } // Configure declarations - configOpts, err := config.AsEnvOptions(env.CELTypeProvider()) - if err != nil { - t.Fatalf("config.AsEnvOptions() failed: %v", err) - } - env, err = env.Extend(configOpts...) + env, err = env.Extend(FromConfig(config)) if err != nil { t.Fatalf("env.Extend() with config options %v, failed: %v", config, err) } diff --git a/policy/config.go b/policy/config.go index 0501709b6..0bd7902c1 100644 --- a/policy/config.go +++ b/policy/config.go @@ -15,123 +15,40 @@ package policy import ( - "errors" "fmt" - "google.golang.org/protobuf/proto" - "github.com/google/cel-go/cel" - "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/env" - "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/ext" ) -// NewConfig returns a YAML serializable policy environment. -func NewConfig(e *env.Config) *Config { - return &Config{Config: e} -} - -// Config represents a YAML serializable CEL environment configuration. -type Config struct { - *env.Config +// FromConfig configures a CEL policy environment from a config file. +// +// This option supports all extensions supported by policies, whereas the cel.FromConfig supports +// a set of configuration ConfigOptionFactory values to handle extensions and other config features +// which may be defined outside of the `cel` package. 
+func FromConfig(config *env.Config) cel.EnvOption { + return cel.FromConfig(config, extensionOptionFactory) } -// AsEnvOptions converts the Config value to a collection of cel environment options. -func (c *Config) AsEnvOptions(provider types.Provider) ([]cel.EnvOption, error) { - envOpts := []cel.EnvOption{} - // Configure the standard lib subset. - if c.StdLib != nil { - if c.StdLib.Disabled { - envOpts = append(envOpts, func(e *cel.Env) (*cel.Env, error) { - if !e.HasLibrary("cel.lib.std") { - return e, nil - } - return cel.NewCustomEnv() - }) - } else { - envOpts = append(envOpts, func(e *cel.Env) (*cel.Env, error) { - return cel.NewCustomEnv(cel.StdLib(cel.StdLibSubset(c.StdLib))) - }) - } - } - - // Configure the container - if c.Container != "" { - envOpts = append(envOpts, cel.Container(c.Container)) - } - - // Configure abbreviations - for _, imp := range c.Imports { - envOpts = append(envOpts, cel.Abbrevs(imp.Name)) +// extensionOptionFactory converts an ExtensionConfig value to a CEL environment option. 
+func extensionOptionFactory(configElement any) (cel.EnvOption, bool) { + ext, isExtension := configElement.(*env.Extension) + if !isExtension { + return nil, false } - - // Configure the context variable declaration - if c.ContextVariable != nil { - if len(c.Variables) > 0 { - return nil, errors.New("either the context_variable or the variables may be set, but not both") - } - typeName := c.ContextVariable.TypeName - if typeName == "" { - return nil, errors.New("invalid context variable, must set type name field") - } - if _, found := provider.FindStructType(typeName); !found { - return nil, fmt.Errorf("could not find context proto type name: %s", typeName) - } - // Attempt to instantiate the proto in order to reflect to its descriptor - msg := provider.NewValue(typeName, map[string]ref.Val{}) - pbMsg, ok := msg.Value().(proto.Message) - if !ok { - return nil, fmt.Errorf("type name was not a protobuf: %T", msg.Value()) - } - envOpts = append(envOpts, cel.DeclareContextProto(pbMsg.ProtoReflect().Descriptor())) - } - - if len(c.Variables) != 0 { - vars := make([]*decls.VariableDecl, 0, len(c.Variables)) - for _, v := range c.Variables { - vDef, err := v.AsCELVariable(provider) - if err != nil { - return nil, err - } - vars = append(vars, vDef) - } - envOpts = append(envOpts, cel.VariableDecls(vars...)) - } - if len(c.Functions) != 0 { - funcs := make([]*decls.FunctionDecl, 0, len(c.Functions)) - for _, f := range c.Functions { - fnDef, err := f.AsCELFunction(provider) - if err != nil { - return nil, err - } - funcs = append(funcs, fnDef) - } - envOpts = append(envOpts, cel.FunctionDecls(funcs...)) - } - for _, e := range c.Extensions { - opt, err := extensionEnvOption(e) - if err != nil { - return nil, err - } - envOpts = append(envOpts, opt) - } - return envOpts, nil -} - -// extensionEnvOption converts an ExtensionConfig value to a CEL environment option. 
-func extensionEnvOption(ec *env.Extension) (cel.EnvOption, error) { - fac, found := extFactories[ec.Name] + fac, found := extFactories[ext.Name] if !found { - return nil, fmt.Errorf("unrecognized extension: %s", ec.Name) + return nil, false } // If the version is 'latest', set the version value to the max uint. - ver, err := ec.GetVersion() + ver, err := ext.GetVersion() if err != nil { - return nil, err + return func(*cel.Env) (*cel.Env, error) { + return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version) + }, true } - return fac(ver), nil + return fac(ver), true } // extensionFactory accepts a version and produces a CEL environment associated with the versioned extension. diff --git a/policy/config_test.go b/policy/config_test.go index df28990a7..c61fe58d6 100644 --- a/policy/config_test.go +++ b/policy/config_test.go @@ -103,7 +103,7 @@ variables: } for _, tst := range tests { c := parseConfigYaml(t, tst) - _, err := c.AsEnvOptions(baseEnv.CELTypeProvider()) + _, err := baseEnv.Extend(FromConfig(c)) if err != nil { t.Errorf("AsEnvOptions() generated error: %v", err) } @@ -127,7 +127,7 @@ variables: - name: "bad_type" type: type_name: "strings"`, - err: "invalid variable type for 'bad_type': undefined type name: strings", + err: `invalid variable "bad_type": undefined type name: "strings"`, }, { config: ` @@ -135,7 +135,7 @@ variables: - name: "bad_list" type: type_name: "list"`, - err: "invalid variable type for 'bad_list': list type has unexpected param count: 0", + err: `invalid variable "bad_list": invalid type: list expects 1 parameter, got 0`, }, { config: ` @@ -145,7 +145,7 @@ variables: type_name: "map" params: - type_name: "string"`, - err: "invalid variable type for 'bad_map': map type has unexpected param count: 1", + err: `invalid variable "bad_map": invalid type: map expects 2 parameters, got 1`, }, { config: ` @@ -155,7 +155,7 @@ variables: type_name: "list" params: - type_name: "number"`, - err: "invalid variable type for 
'bad_list_type_param': undefined type name: number", + err: `invalid variable "bad_list_type_param": undefined type name: "number"`, }, { config: ` @@ -166,21 +166,21 @@ variables: params: - type_name: "string" - type_name: "invalid_opaque_type"`, - err: "invalid variable type for 'bad_map_type_param': undefined type name: invalid_opaque_type", + err: `invalid variable "bad_map_type_param": undefined type name: "invalid_opaque_type"`, }, { config: ` context_variable: type_name: "bad.proto.MessageType" `, - err: "could not find context proto type name: bad.proto.MessageType", + err: `invalid context proto type: "bad.proto.MessageType"`, }, { config: ` variables: - type: type_name: "no variable name"`, - err: "invalid variable, must declare a name", + err: "invalid variable: missing variable name", }, { @@ -191,7 +191,7 @@ functions: - id: "zero_arity" return: type_name: "mystery"`, - err: "undefined type name: mystery", + err: `invalid function "bad_return": undefined type name: "mystery"`, }, { config: ` @@ -203,7 +203,7 @@ functions: type_name: "unknown" return: type_name: "null_type"`, - err: "undefined type name: unknown", + err: `invalid function "bad_target": undefined type name: "unknown"`, }, { config: ` @@ -215,7 +215,7 @@ functions: - type_name: "unknown" return: type_name: "null_type"`, - err: "undefined type name: unknown", + err: `invalid function "bad_arg": undefined type name: "unknown"`, }, { config: ` @@ -225,7 +225,7 @@ functions: - id: "unary_global" args: - type_name: "null_type"`, - err: "missing return type on overload: unary_global", + err: `invalid function "missing_return": invalid overload "unary_global" return: invalid type: nil`, }, } baseEnv, err := cel.NewEnv(cel.OptionalTypes()) @@ -234,17 +234,17 @@ functions: } for _, tst := range tests { c := parseConfigYaml(t, tst.config) - _, err := c.AsEnvOptions(baseEnv.CELTypeProvider()) + _, err := baseEnv.Extend(FromConfig(c)) if err == nil || err.Error() != tst.err { t.Errorf("AsEnvOptions() 
got error: %v, wanted %s", err, tst.err) } } } -func parseConfigYaml(t *testing.T, doc string) *Config { +func parseConfigYaml(t *testing.T, doc string) *env.Config { config := &env.Config{} if err := yaml.Unmarshal([]byte(doc), config); err != nil { t.Fatalf("yaml.Unmarshal(%q) failed: %v", doc, err) } - return NewConfig(config) + return config } diff --git a/policy/helper_test.go b/policy/helper_test.go index ba612fc08..3934301fe 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -278,7 +278,7 @@ func readPolicy(t testing.TB, fileName string) *Source { return ByteSource(policyBytes, fileName) } -func readPolicyConfig(t testing.TB, fileName string) *Config { +func readPolicyConfig(t testing.TB, fileName string) *env.Config { t.Helper() testCaseBytes, err := os.ReadFile(fileName) if err != nil { @@ -289,7 +289,7 @@ func readPolicyConfig(t testing.TB, fileName string) *Config { if err != nil { log.Fatalf("yaml.Unmarshal(%s) error: %v", fileName, err) } - return NewConfig(config) + return config } func readTestSuite(t testing.TB, fileName string) *TestSuite { From 45c4980b2ce9d04234597eb2771258d7f1bb10ca Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 18 Feb 2025 16:48:30 -0800 Subject: [PATCH 13/46] Support for feature flags and validators in env.Config (#1132) * Support for feature flags and validators in env.Config * Minor update to documentation --- cel/env.go | 26 +++- cel/env_test.go | 215 ++++++++++++++++++++++++-- cel/options.go | 92 +++++++++-- cel/validator.go | 70 ++++++++- common/env/env.go | 92 ++++++++++- common/env/env_test.go | 143 ++++++++++++++++- common/env/testdata/extended_env.yaml | 10 ++ ext/bindings.go | 2 +- ext/formatting.go | 2 +- policy/config.go | 5 +- 10 files changed, 615 insertions(+), 42 deletions(-) diff --git a/cel/env.go b/cel/env.go index 5bebfbc30..3e2c62876 100644 --- a/cel/env.go +++ b/cel/env.go @@ -263,6 +263,25 @@ func (e *Env) ToConfig(name string) (*env.Config, error) { } } + // Serialize 
validators + for _, val := range e.Validators() { + // Only add configurable validators to the env.Config as all others are + // expected to be implicitly enabled via extension libraries. + if confVal, ok := val.(ConfigurableASTValidator); ok { + conf.AddValidators(confVal.ToConfig()) + } + } + + // Serialize features + for featID, enabled := range e.features { + featName, found := featureNameByID(featID) + if !found { + // If the feature isn't named, it isn't intended to be publicly exposed + continue + } + conf.AddFeatures(env.NewFeature(featName, enabled)) + } + return conf, nil } @@ -541,7 +560,7 @@ func (e *Env) Functions() map[string]*decls.FunctionDecl { // Variables returns the set of variables associated with the environment. func (e *Env) Variables() []*decls.VariableDecl { - return e.variables + return e.variables[:] } // HasValidator returns whether a specific ASTValidator has been configured in the environment. @@ -554,6 +573,11 @@ func (e *Env) HasValidator(name string) bool { return false } +// Validators returns the set of ASTValidators configured on the environment. +func (e *Env) Validators() []ASTValidator { + return e.validators[:] +} + // Parse parses the input expression value `txt` to a Ast and/or a set of Issues. 
// // This form of Parse creates a Source value for the input `txt` and forwards to the diff --git a/cel/env_test.go b/cel/env_test.go index 55e16d981..38ae5e4cd 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -24,6 +24,8 @@ import ( "testing" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" @@ -401,6 +403,30 @@ func TestEnvToConfig(t *testing.T) { }, want: env.NewConfig("context proto").SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")), }, + { + name: "feature flags", + opts: []EnvOption{ + DefaultUTCTimeZone(false), + EnableMacroCallTracking(), + }, + want: env.NewConfig("feature flags").AddFeatures( + env.NewFeature("cel.feature.macro_call_tracking", true), + ), + }, + { + name: "validators", + opts: []EnvOption{ + ExtendedValidations(), + ASTValidators(ValidateComprehensionNestingLimit(1)), + }, + want: env.NewConfig("validators").AddValidators( + env.NewValidator("cel.validator.duration"), + env.NewValidator("cel.validator.timestamp"), + env.NewValidator("cel.validator.matches"), + env.NewValidator("cel.validator.homogeneous_literals"), + env.NewValidator("cel.validator.comprehension_nesting_limit").SetConfig(map[string]any{"limit": 1}), + ), + }, } for _, tst := range tests { @@ -430,11 +456,12 @@ func TestEnvFromConfig(t *testing.T) { out ref.Val } tests := []struct { - name string - beforeOpts []EnvOption - afterOpts []EnvOption - conf *env.Config - exprs []exprCase + name string + beforeOpts []EnvOption + afterOpts []EnvOption + conf *env.Config + confHandlers []ConfigOptionFactory + exprs []exprCase }{ { name: "std env", @@ -617,18 +644,138 @@ func TestEnvFromConfig(t *testing.T) { }, }, }, + { + name: "extensions - config factory", + conf: env.NewConfig("extensions"). 
+ AddExtensions(env.NewExtension("plus", math.MaxUint32)), + confHandlers: []ConfigOptionFactory{ + func(a any) (EnvOption, bool) { + ext, ok := a.(*env.Extension) + if !ok || ext.Name != "plus" { + return nil, false + } + return Function("plus", Overload("plus_int_int", []*Type{IntType, IntType}, IntType, + decls.BinaryBinding(func(lhs, rhs ref.Val) ref.Val { + l := lhs.(types.Int) + r := rhs.(types.Int) + return l + r + }))), true + }, + }, + exprs: []exprCase{ + { + name: "plus", + expr: "plus(1, 2)", + out: types.Int(3), + }, + }, + }, + { + name: "features", + conf: env.NewConfig("features"). + AddVariables( + env.NewVariable("m", + env.NewTypeDesc("map", env.NewTypeDesc("string"), env.NewTypeDesc("string")))). + AddFeatures( + env.NewFeature("cel.feature.backtick_escape_syntax", true), + env.NewFeature("cel.feature.unknown_feature_name", true)), + exprs: []exprCase{ + { + name: "optional key", + expr: "m.`key-name` == 'value'", + in: map[string]any{"m": map[string]string{"key-name": "value"}}, + out: types.True, + }, + }, + }, + { + name: "validators", + conf: env.NewConfig("validators"). + AddVariables( + env.NewVariable("m", + env.NewTypeDesc("map", env.NewTypeDesc("string"), env.NewTypeDesc("string"))), + ). 
+ AddValidators( + env.NewValidator(durationValidatorName), + env.NewValidator(timestampValidatorName), + env.NewValidator(regexValidatorName), + env.NewValidator(homogeneousValidatorName), + env.NewValidator(nestingLimitValidatorName).SetConfig(map[string]any{"limit": 0}), + ), + exprs: []exprCase{ + { + name: "bad duration", + expr: "duration('1')", + iss: errors.New("invalid duration"), + }, + { + name: "bad timestamp", + expr: "timestamp('1')", + iss: errors.New("invalid timestamp"), + }, + { + name: "bad regex", + expr: "'hello'.matches('?^()')", + iss: errors.New("invalid matches"), + }, + { + name: "mixed type list", + expr: "[1, 2.0]", + iss: errors.New("expected type 'int'"), + }, + { + name: "disabled comprehension", + expr: "[1, 2].exists(x, x % 2 == 0)", + iss: errors.New("comprehension exceeds nesting limit"), + }, + }, + }, + { + name: "validators - config factory", + conf: env.NewConfig("validators"). + AddValidators( + env.NewValidator("cel.validators.return_type").SetConfig(map[string]any{"type_name": "string"}), + ), + confHandlers: []ConfigOptionFactory{ + func(a any) (EnvOption, bool) { + val, ok := a.(*env.Validator) + if !ok || val.Name != "cel.validators.return_type" { + return nil, false + } + typeName, found := val.ConfigValue("type_name") + if !found { + return func(*Env) (*Env, error) { + return nil, fmt.Errorf("invalid validator: %s missing config parameter 'type_name'", val.Name) + }, true + } + return func(e *Env) (*Env, error) { + t, err := env.NewTypeDesc(typeName.(string)).AsCELType(e.CELTypeProvider()) + if err != nil { + return nil, err + } + return ASTValidators(returnTypeValidator{returnType: t})(e) + }, true + }, + }, + exprs: []exprCase{ + { + name: "string - ok", + expr: "'hello'", + out: types.String("hello"), + }, + { + name: "int - error", + expr: "1", + iss: errors.New("unsupported return type: int, want string"), + }, + }, + }, } for _, tst := range tests { tc := tst t.Run(tc.name, func(t *testing.T) { opts := 
tc.beforeOpts - opts = append(opts, FromConfig(tc.conf, func(elem any) (EnvOption, bool) { - if ext, ok := elem.(*env.Extension); ok && ext.Name == "optional" { - ver, _ := ext.GetVersion() - return OptionalTypes(OptionalTypesVersion(ver)), true - } - return nil, false - })) + opts = append(opts, FromConfig(tc.conf, tc.confHandlers...)) opts = append(opts, tc.afterOpts...) var e *Env var err error @@ -679,6 +826,16 @@ func TestEnvFromConfigErrors(t *testing.T) { conf *env.Config want error }{ + { + name: "bad container", + conf: env.NewConfig("bad container").SetContainer(".hello.world"), + want: errors.New("container name must not contain"), + }, + { + name: "colliding imports", + conf: env.NewConfig("colliding imports").AddImports(env.NewImport("pkg.ImportName"), env.NewImport("pkg2.ImportName")), + want: errors.New("abbreviation collides"), + }, { name: "invalid subset", conf: env.NewConfig("invalid subset").SetStdLib(env.NewLibrarySubset().SetDisableMacros(true)), @@ -707,9 +864,21 @@ func TestEnvFromConfigErrors(t *testing.T) { { name: "unrecognized extension", conf: env.NewConfig("unrecognized extension"). - AddExtensions(env.NewExtension("optional", math.MaxUint32)), + AddExtensions(env.NewExtension("unrecognized", math.MaxUint32)), want: errors.New("unrecognized extension"), }, + { + name: "invalid validator config", + conf: env.NewConfig("invalid validator config"). + AddValidators(env.NewValidator("cel.validator.comprehension_nesting_limit")), + want: errors.New("invalid validator"), + }, + { + name: "invalid validator config type", + conf: env.NewConfig("invalid validator config"). 
+ AddValidators(env.NewValidator("cel.validator.comprehension_nesting_limit").SetConfig(map[string]any{"limit": 2.0})), + want: errors.New("invalid validator"), + }, } for _, tst := range tests { tc := tst @@ -829,6 +998,26 @@ func mustContextProto(t *testing.T, pb proto.Message) Activation { return ctx } +type returnTypeValidator struct { + returnType *Type +} + +func (returnTypeValidator) Name() string { + return "cel.validators.return_type" +} + +func (v returnTypeValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) { + if a.GetType(a.Expr().ID()) != v.returnType { + iss.ReportErrorAtID(a.Expr().ID(), + "unsupported return type: %s, want %s", + a.GetType(a.Expr().ID()), v.returnType.TypeName()) + } +} + +func (v returnTypeValidator) ToConfig() *env.Validator { + return env.NewValidator(v.Name()).SetConfig(map[string]any{"type_name": v.returnType.TypeName()}) +} + type customLegacyProvider struct { provider ref.TypeProvider } diff --git a/cel/options.go b/cel/options.go index 06d37049d..33a6c5e76 100644 --- a/cel/options.go +++ b/cel/options.go @@ -73,6 +73,26 @@ const ( featureIdentEscapeSyntax ) +var featureIDsToNames = map[int]string{ + featureEnableMacroCallTracking: "cel.feature.macro_call_tracking", + featureCrossTypeNumericComparisons: "cel.feature.cross_type_numeric_comparisons", + featureIdentEscapeSyntax: "cel.feature.backtick_escape_syntax", +} + +func featureNameByID(id int) (string, bool) { + name, found := featureIDsToNames[id] + return name, found +} + +func featureIDByName(name string) (int, bool) { + for id, n := range featureIDsToNames { + if n == name { + return id, true + } + } + return 0, false +} + // EnvOption is a functional interface for configuring the environment. 
type EnvOption func(e *Env) (*Env, error) @@ -456,28 +476,27 @@ type ConfigOptionFactory func(any) (EnvOption, bool) // as the type provider configured at the time when the config is processed is the one used to derive // type references from the configuration. func FromConfig(config *env.Config, optFactories ...ConfigOptionFactory) EnvOption { - return func(env *Env) (*Env, error) { + return func(e *Env) (*Env, error) { if err := config.Validate(); err != nil { return nil, err } - opts, err := configToEnvOptions(config, env.CELTypeProvider(), optFactories) + opts, err := configToEnvOptions(config, e.CELTypeProvider(), optFactories) if err != nil { return nil, err } for _, o := range opts { - env, err = o(env) + e, err = o(e) if err != nil { return nil, err } } - return env, nil + return e, nil } } // configToEnvOptions generates a set of EnvOption values (or error) based on a config, a type provider, // and an optional set of environment options. func configToEnvOptions(config *env.Config, provider types.Provider, optFactories []ConfigOptionFactory) ([]EnvOption, error) { - // note: ported from cel-go/policy/config.go envOpts := []EnvOption{} // Configure the standard lib subset. 
if config.StdLib != nil { @@ -519,6 +538,7 @@ func configToEnvOptions(config *env.Config, provider types.Provider, optFactorie envOpts = append(envOpts, DeclareContextProto(pbMsg.ProtoReflect().Descriptor())) } + // Configure variables if len(config.Variables) != 0 { vars := make([]*decls.VariableDecl, 0, len(config.Variables)) for _, v := range config.Variables { @@ -530,6 +550,8 @@ func configToEnvOptions(config *env.Config, provider types.Provider, optFactorie } envOpts = append(envOpts, VariableDecls(vars...)) } + + // Configure functions if len(config.Functions) != 0 { funcs := make([]*decls.FunctionDecl, 0, len(config.Functions)) for _, f := range config.Functions { @@ -541,22 +563,62 @@ func configToEnvOptions(config *env.Config, provider types.Provider, optFactorie } envOpts = append(envOpts, FunctionDecls(funcs...)) } - for _, e := range config.Extensions { - extHandled := false - for _, optFac := range optFactories { - if opt, useOption := optFac(e); useOption { - envOpts = append(envOpts, opt) - extHandled = true - break - } + + // Configure features + for _, feat := range config.Features { + // Note, if a feature is not found, it is skipped as it is possible the feature + // is not intended to be supported publicly. 
In the future, a refinement of + // to this strategy to report unrecognized features and validators should probably + // be covered as a standard ConfigOptionFactory + if id, found := featureIDByName(feat.Name); found { + envOpts = append(envOpts, features(id, feat.Enabled)) } - if !extHandled { - return nil, fmt.Errorf("unrecognized extension: %s", e.Name) + } + + // Configure validators + for _, val := range config.Validators { + if fac, found := astValidatorFactories[val.Name]; found { + envOpts = append(envOpts, func(e *Env) (*Env, error) { + validator, err := fac(val) + if err != nil { + return nil, fmt.Errorf("%w", err) + } + return ASTValidators(validator)(e) + }) + } else if opt, handled := handleExtendedConfigOption(val, optFactories); handled { + envOpts = append(envOpts, opt) + } + // we don't error when the validator isn't found as it may be part + // of an extension library and enabled implicitly. + } + + // Configure extensions + for _, ext := range config.Extensions { + // version number has been validated by the call to `Validate` + ver, _ := ext.VersionNumber() + if ext.Name == "optional" { + envOpts = append(envOpts, OptionalTypes(OptionalTypesVersion(ver))) + } else { + opt, handled := handleExtendedConfigOption(ext, optFactories) + if !handled { + return nil, fmt.Errorf("unrecognized extension: %s", ext.Name) + } + envOpts = append(envOpts, opt) } } + return envOpts, nil } +func handleExtendedConfigOption(conf any, optFactories []ConfigOptionFactory) (EnvOption, bool) { + for _, optFac := range optFactories { + if opt, useOption := optFac(conf); useOption { + return opt, true + } + } + return nil, false +} + // EvalOption indicates an evaluation option that may affect the evaluation behavior or information // in the output result. 
type EvalOption int diff --git a/cel/validator.go b/cel/validator.go index b50c67452..5f06b2dd5 100644 --- a/cel/validator.go +++ b/cel/validator.go @@ -20,11 +20,16 @@ import ( "regexp" "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/overloads" ) const ( - homogeneousValidatorName = "cel.lib.std.validate.types.homogeneous" + durationValidatorName = "cel.validator.duration" + regexValidatorName = "cel.validator.matches" + timestampValidatorName = "cel.validator.timestamp" + homogeneousValidatorName = "cel.validator.homogeneous_literals" + nestingLimitValidatorName = "cel.validator.comprehension_nesting_limit" // HomogeneousAggregateLiteralExemptFunctions is the ValidatorConfig key used to configure // the set of function names which are exempt from homogeneous type checks. The expected type @@ -36,6 +41,35 @@ const ( HomogeneousAggregateLiteralExemptFunctions = homogeneousValidatorName + ".exempt" ) +var ( + astValidatorFactories = map[string]ASTValidatorFactory{ + nestingLimitValidatorName: func(val *env.Validator) (ASTValidator, error) { + if limit, found := val.ConfigValue("limit"); found { + if val, isInt := limit.(int); isInt { + return ValidateComprehensionNestingLimit(val), nil + } + return nil, fmt.Errorf("invalid validator: %s unsupported limit type: %v", nestingLimitValidatorName, limit) + } + return nil, fmt.Errorf("invalid validator: %s missing limit", nestingLimitValidatorName) + }, + durationValidatorName: func(*env.Validator) (ASTValidator, error) { + return ValidateDurationLiterals(), nil + }, + regexValidatorName: func(*env.Validator) (ASTValidator, error) { + return ValidateRegexLiterals(), nil + }, + timestampValidatorName: func(*env.Validator) (ASTValidator, error) { + return ValidateTimestampLiterals(), nil + }, + homogeneousValidatorName: func(*env.Validator) (ASTValidator, error) { + return ValidateHomogeneousAggregateLiterals(), nil + }, + } +) + +// ASTValidatorFactory creates 
an ASTValidator as configured by the input map +type ASTValidatorFactory func(*env.Validator) (ASTValidator, error) + // ASTValidators configures a set of ASTValidator instances into the target environment. // // Validators are applied in the order in which the are specified and are treated as singletons. @@ -70,6 +104,18 @@ type ASTValidator interface { Validate(*Env, ValidatorConfig, *ast.AST, *Issues) } +// ConfigurableASTValidator supports conversion of an object to an `env.Validator` instance used for +// YAML serialization. +type ConfigurableASTValidator interface { + // ToConfig converts the internal configuration of an ASTValidator into an env.Validator instance + // which minimally must include the validator name, but may also include a map[string]any config + // object to be serialized to YAML. The string keys represent the configuration parameter name, + // and the any value must mirror the internally supported type associated with the config key. + // + // Note: only primitive CEL types are supported by CEL validators at this time. + ToConfig() *env.Validator +} + // ValidatorConfig provides an accessor method for querying validator configuration state. type ValidatorConfig interface { GetOrDefault(name string, value any) any @@ -196,7 +242,12 @@ type formatValidator struct { // Name returns the unique name of this function format validator. func (v formatValidator) Name() string { - return fmt.Sprintf("cel.lib.std.validate.functions.%s", v.funcName) + return fmt.Sprintf("cel.validator.%s", v.funcName) +} + +// ToConfig converts the ASTValidator to an env.Validator specifying the validator name. 
+func (v formatValidator) ToConfig() *env.Validator { + return env.NewValidator(v.Name()) } // Validate searches the AST for uses of a given function name with a constant argument and performs a check @@ -242,6 +293,11 @@ func (homogeneousAggregateLiteralValidator) Name() string { return homogeneousValidatorName } +// ToConfig converts the ASTValidator to an env.Validator specifying the validator name. +func (v homogeneousAggregateLiteralValidator) ToConfig() *env.Validator { + return env.NewValidator(v.Name()) +} + // Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types. // // This validator makes an exception for list and map literals which occur at any level of nesting within @@ -336,10 +392,18 @@ type nestingLimitValidator struct { limit int } +// Name returns the name of the nesting limit validator. func (v nestingLimitValidator) Name() string { - return "cel.lib.std.validate.comprehension_nesting_limit" + return nestingLimitValidatorName +} + +// ToConfig converts the ASTValidator to an env.Validator specifying the validator name and the nesting limit +// as an integer value: {"limit": int} +func (v nestingLimitValidator) ToConfig() *env.Validator { + return env.NewValidator(v.Name()).SetConfig(map[string]any{"limit": v.limit}) } +// Validate implements the ASTValidator interface method. 
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) { root := ast.NavigateAST(a) comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) diff --git a/common/env/env.go b/common/env/env.go index 10c7b1e72..aa7b066c5 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -48,6 +48,8 @@ type Config struct { ContextVariable *ContextVariable `yaml:"context_variable,omitempty"` Variables []*Variable `yaml:"variables,omitempty"` Functions []*Function `yaml:"functions,omitempty"` + Validators []*Validator `yaml:"validators,omitempty"` + Features []*Feature `yaml:"features,omitempty"` } // Validate validates the whole configuration is well-formed. @@ -85,6 +87,16 @@ func (c *Config) Validate() error { errs = append(errs, err) } } + for _, feat := range c.Features { + if err := feat.Validate(); err != nil { + errs = append(errs, err) + } + } + for _, val := range c.Validators { + if err := val.Validate(); err != nil { + errs = append(errs, err) + } + } return errors.Join(errs...) } @@ -172,6 +184,18 @@ func (c *Config) AddExtensions(exts ...*Extension) *Config { return c } +// AddValidators appends one or more validators to the config. +func (c *Config) AddValidators(vals ...*Validator) *Config { + c.Validators = append(c.Validators, vals...) + return c +} + +// AddFeatures appends one or more features to the config. +func (c *Config) AddFeatures(feats ...*Feature) *Config { + c.Features = append(c.Features, feats...) + return c +} + // NewImport returns a serializable import value from the qualified type name. func NewImport(name string) *Import { return &Import{Name: name} @@ -428,12 +452,12 @@ type Extension struct { // Validate validates the extension configuration is well-formed. func (e *Extension) Validate() error { - _, err := e.GetVersion() + _, err := e.VersionNumber() return err } -// GetVersion returns the parsed version string, or an error if the version cannot be parsed. 
-func (e *Extension) GetVersion() (uint32, error) { +// VersionNumber returns the parsed version string, or an error if the version cannot be parsed. +func (e *Extension) VersionNumber() (uint32, error) { if e == nil { return 0, fmt.Errorf("invalid extension: nil") } @@ -625,6 +649,68 @@ func (lib *LibrarySubset) AddExcludedFunctions(funcs ...*Function) *LibrarySubse return lib } +// NewValidator returns a named Validator instance. +func NewValidator(name string) *Validator { + return &Validator{Name: name} +} + +// Validator represents a named validator with an optional map-based configuration object. +// +// Note: the map-keys must directly correspond to the internal representation of the original +// validator, and should only use primitive scalar types as values at this time. +type Validator struct { + Name string `yaml:"name"` + Config map[string]any `yaml:"config,omitempty"` +} + +// Validate validates the configuration of the validator object. +func (v *Validator) Validate() error { + if v == nil { + return errors.New("invalid validator: nil") + } + if v.Name == "" { + return errors.New("invalid validator: missing name") + } + return nil +} + +// SetConfig sets the set of map key-value pairs associated with this validator's configuration. +func (v *Validator) SetConfig(config map[string]any) *Validator { + v.Config = config + return v +} + +// ConfigValue retrieves the value associated with the config key name, if one exists. +func (v *Validator) ConfigValue(name string) (any, bool) { + if v == nil { + return nil, false + } + value, found := v.Config[name] + return value, found +} + +// NewFeature creates a new feature flag with a boolean enablement flag. +func NewFeature(name string, enabled bool) *Feature { + return &Feature{Name: name, Enabled: enabled} +} + +// Feature represents a named boolean feature flag supported by CEL. 
+type Feature struct { + Name string `yaml:"name"` + Enabled bool `yaml:"enabled"` +} + +// Validate validates whether the feature is well-configured. +func (feat *Feature) Validate() error { + if feat == nil { + return errors.New("invalid feature: nil") + } + if feat.Name == "" { + return errors.New("invalid feature: missing name") + } + return nil +} + // NewTypeDesc describes a simple or complex type with parameters. func NewTypeDesc(typeName string, params ...*TypeDesc) *TypeDesc { return &TypeDesc{TypeName: typeName, Params: params} diff --git a/common/env/env_test.go b/common/env/env_test.go index 1a18038a8..2e2cd1d86 100644 --- a/common/env/env_test.go +++ b/common/env/env_test.go @@ -86,6 +86,14 @@ func TestConfig(t *testing.T) { NewTypeDesc("list", NewTypeParam("T")), nil, NewTypeDesc("bool")), ), + ).AddFeatures( + NewFeature("cel.feature.macro_call_tracking", true), + ).AddValidators( + NewValidator("cel.validator.duration"), + NewValidator("cel.validator.matches"), + NewValidator("cel.validator.timestamp"), + NewValidator("cel.validator.nesting_comprehension_limit"). 
+ SetConfig(map[string]any{"limit": 2}), ), }, { @@ -168,6 +176,32 @@ func TestConfig(t *testing.T) { } } } + if len(got.Features) != len(tc.want.Features) { + t.Errorf("Features count got %d, wanted %d", len(got.Features), len(tc.want.Features)) + } else { + for i, f := range got.Features { + wf := tc.want.Features[i] + if f.Name != wf.Name { + t.Errorf("Features[%d] got name %s, wanted %s", i, f.Name, wf.Name) + } + if f.Enabled != wf.Enabled { + t.Errorf("Features[%d] got enabled %t, wanted %t", i, f.Enabled, wf.Enabled) + } + } + } + if len(got.Validators) != len(tc.want.Validators) { + t.Errorf("Validators count got %d, wanted %d", len(got.Validators), len(tc.want.Validators)) + } else { + for i, f := range got.Validators { + wf := tc.want.Validators[i] + if f.Name != wf.Name { + t.Errorf("Validators[%d] got name %s, wanted %s", i, f.Name, wf.Name) + } + if !reflect.DeepEqual(f.Config, wf.Config) { + t.Errorf("Validators[%d] got enabled %v, wanted %v", i, f.Config, wf.Config) + } + } + } }) } } @@ -206,11 +240,28 @@ func TestConfigValidateErrors(t *testing.T) { in: NewConfig("invalid variable").AddVariables(NewVariable("", nil)), want: errors.New("invalid variable"), }, + { + name: "colliding context variable", + in: NewConfig("colliding context variable"). + SetContextVariable(NewContextVariable("msg.type.Name")). 
+ AddVariables(NewVariable("local", NewTypeDesc("string"))), + want: errors.New("invalid config"), + }, { name: "invalid function", in: NewConfig("invalid function").AddFunctions(NewFunction("", nil)), want: errors.New("invalid function"), }, + { + name: "invalid feature", + in: NewConfig("invalid feature").AddFeatures(NewFeature("", false)), + want: errors.New("invalid feature"), + }, + { + name: "invalid validator", + in: NewConfig("invalid validator").AddValidators(NewValidator("")), + want: errors.New("invalid validator"), + }, } for _, tst := range tests { @@ -1116,7 +1167,7 @@ func TestExtensionGetVersion(t *testing.T) { for _, tst := range tests { tc := tst t.Run(tc.name, func(t *testing.T) { - ver, err := tc.ext.GetVersion() + ver, err := tc.ext.VersionNumber() if err != nil { wantErr, ok := tc.want.(error) if !ok { @@ -1134,6 +1185,96 @@ func TestExtensionGetVersion(t *testing.T) { } } +func TestValidatorValidate(t *testing.T) { + tests := []struct { + name string + v *Validator + want error + }{ + { + name: "nil validator", + v: nil, + want: errors.New("invalid validator: nil"), + }, + { + name: "empty validator", + v: NewValidator(""), + want: errors.New("missing name"), + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + err := tc.v.Validate() + if err == nil && tc.want == nil { + return + } + if err == nil && tc.want != nil { + t.Fatalf("v.Validate() got valid, wanted error %v", tc.want) + } + if err != nil && tc.want == nil { + t.Fatalf("v.Validate() got error %v, wanted nil error", err) + } + if !strings.Contains(err.Error(), tc.want.Error()) { + t.Errorf("v.Validate() got error %v, wanted %v", err, tc.want) + } + }) + } +} + +func TestValidatorConfigValue(t *testing.T) { + var v *Validator + if _, found := v.ConfigValue("limit"); found { + t.Error("v.ConfigValue() got value from nil validator") + } + v = NewValidator("validator").SetConfig(map[string]any{"limit": 2}) + if _, found := v.ConfigValue("absent"); 
found { + t.Error("v.ConfigValue() found absent key") + } + if val, found := v.ConfigValue("limit"); !found || val != 2 { + t.Errorf("v.ConfigValue() got %v, %t -- wanted 2, true", val, found) + } +} + +func TestFeatureValidate(t *testing.T) { + tests := []struct { + name string + f *Feature + want error + }{ + { + name: "nil feature", + f: nil, + want: errors.New("invalid feature: nil"), + }, + { + name: "empty feature", + f: NewFeature("", true), + want: errors.New("missing name"), + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + err := tc.f.Validate() + if err == nil && tc.want == nil { + return + } + if err == nil && tc.want != nil { + t.Fatalf("f.Validate() got valid, wanted error %v", tc.want) + } + if err != nil && tc.want == nil { + t.Fatalf("f.Validate() got error %v, wanted nil error", err) + } + if !strings.Contains(err.Error(), tc.want.Error()) { + t.Errorf("f.Validate() got error %v, wanted %v", err, tc.want) + } + }) + } +} + func mustNewFunction(t *testing.T, name string, opts ...decls.FunctionOpt) *decls.FunctionDecl { t.Helper() fn, err := decls.NewFunction(name, opts...) diff --git a/common/env/testdata/extended_env.yaml b/common/env/testdata/extended_env.yaml index 041002e75..808e0a89e 100644 --- a/common/env/testdata/extended_env.yaml +++ b/common/env/testdata/extended_env.yaml @@ -38,3 +38,13 @@ functions: is_type_param: true return: type_name: "bool" +validators: + - name: cel.validator.duration + - name: cel.validator.matches + - name: cel.validator.timestamp + - name: cel.validator.nesting_comprehension_limit + config: + limit: 2 +features: + - name: cel.feature.macro_call_tracking + enabled: true diff --git a/ext/bindings.go b/ext/bindings.go index 81dae50f2..63942b85c 100644 --- a/ext/bindings.go +++ b/ext/bindings.go @@ -149,7 +149,7 @@ type blockValidationExemption struct{} // Name returns the name of the validator. 
func (blockValidationExemption) Name() string { - return "cel.lib.ext.validate.functions.cel.block" + return "cel.validator.cel_block" } // Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip diff --git a/ext/formatting.go b/ext/formatting.go index 932d562ec..aa334ccd9 100644 --- a/ext/formatting.go +++ b/ext/formatting.go @@ -407,7 +407,7 @@ type stringFormatValidator struct{} // Name returns the name of the validator. func (stringFormatValidator) Name() string { - return "cel.lib.ext.validate.functions.string.format" + return "cel.validator.string_format" } // Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip diff --git a/policy/config.go b/policy/config.go index 0bd7902c1..12fbe44a7 100644 --- a/policy/config.go +++ b/policy/config.go @@ -42,7 +42,7 @@ func extensionOptionFactory(configElement any) (cel.EnvOption, bool) { return nil, false } // If the version is 'latest', set the version value to the max uint. 
- ver, err := ext.GetVersion() + ver, err := ext.VersionNumber() if err != nil { return func(*cel.Env) (*cel.Env, error) { return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version) @@ -67,9 +67,6 @@ var extFactories = map[string]extensionFactory{ "math": func(version uint32) cel.EnvOption { return ext.Math(ext.MathVersion(version)) }, - "optional": func(version uint32) cel.EnvOption { - return cel.OptionalTypes(cel.OptionalTypesVersion(version)) - }, "protos": func(version uint32) cel.EnvOption { return ext.Protos(ext.ProtosVersion(version)) }, From fad0c1b7493e76f23b647d3429d1320f8c6bcddb Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 21 Feb 2025 10:12:24 -0800 Subject: [PATCH 14/46] Use remote caching for Cloud Build with Bazel (#1134) * Use remote caching for Cloud Build with Bazel Signed-off-by: Justin King * Also bump machineType to E2_HIGHCPU_32 Signed-off-by: Justin King --------- Signed-off-by: Justin King --- cloudbuild.yaml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 9cfbc223f..4fd87abf9 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -26,9 +26,14 @@ steps: args: ['scripts/verify-vendor.sh'] - name: 'gcr.io/cloud-builders/bazel' entrypoint: bazel - args: ['test', '--test_output=errors', '...'] + args: + - 'test' + - '--test_output=errors' + - '--remote_cache=https://storage.googleapis.com/cel-go-remote-cache' + - '--google_default_credentials' + - '...' 
id: bazel-test waitFor: ['-'] timeout: 10m options: - machineType: 'N1_HIGHCPU_8' + machineType: 'E2_HIGHCPU_32' From 9855c701f050e8726f5818358ef2ee5946f9dcba Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 25 Feb 2025 14:15:21 -0800 Subject: [PATCH 15/46] Support for splitting nested branching operators within policies (#1136) * Support for splitting nested branching operators within policies * Introduce an ast.Heights() helper * Updated tests and expanded flattening to all calls * Added test case for comprehension pruning during unnest --- common/ast/ast.go | 78 +++++++++++++ common/ast/ast_test.go | 25 +++++ common/ast/navigable.go | 7 +- policy/compiler.go | 3 +- policy/compiler_test.go | 175 ++++++++++++++++++++--------- policy/composer.go | 168 +++++++++++++++++++++++++-- policy/helper_test.go | 150 ++++++++++++++++++++++++- policy/testdata/unnest/config.yaml | 20 ++++ policy/testdata/unnest/policy.yaml | 32 ++++++ policy/testdata/unnest/tests.yaml | 50 +++++++++ 10 files changed, 644 insertions(+), 64 deletions(-) create mode 100644 policy/testdata/unnest/config.yaml create mode 100644 policy/testdata/unnest/policy.yaml create mode 100644 policy/testdata/unnest/tests.yaml diff --git a/common/ast/ast.go b/common/ast/ast.go index b807669d4..62c09cfc6 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -160,6 +160,13 @@ func MaxID(a *AST) int64 { return visitor.maxID + 1 } +// Heights computes the heights of all AST expressions and returns a map from expression id to height. +func Heights(a *AST) map[int64]int { + visitor := make(heightVisitor) + PostOrderVisit(a.Expr(), visitor) + return visitor +} + // NewSourceInfo creates a simple SourceInfo object from an input common.Source value. 
func NewSourceInfo(src common.Source) *SourceInfo { var lineOffsets []int32 @@ -455,3 +462,74 @@ func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) { v.maxID = e.ID() } } + +type heightVisitor map[int64]int + +// VisitExpr computes the height of a given node as the max height of its children plus one. +// +// Identifiers and literals are treated as having a height of zero. +func (hv heightVisitor) VisitExpr(e Expr) { + // default includes IdentKind, LiteralKind + hv[e.ID()] = 0 + switch e.Kind() { + case SelectKind: + hv[e.ID()] = 1 + hv[e.AsSelect().Operand().ID()] + case CallKind: + c := e.AsCall() + height := hv.maxHeight(c.Args()...) + if c.IsMemberFunction() { + tHeight := hv[c.Target().ID()] + if tHeight > height { + height = tHeight + } + } + hv[e.ID()] = 1 + height + case ListKind: + l := e.AsList() + hv[e.ID()] = 1 + hv.maxHeight(l.Elements()...) + case MapKind: + m := e.AsMap() + hv[e.ID()] = 1 + hv.maxEntryHeight(m.Entries()...) + case StructKind: + s := e.AsStruct() + hv[e.ID()] = 1 + hv.maxEntryHeight(s.Fields()...) + case ComprehensionKind: + comp := e.AsComprehension() + hv[e.ID()] = 1 + hv.maxHeight(comp.IterRange(), comp.AccuInit(), comp.LoopCondition(), comp.LoopStep(), comp.Result()) + } +} + +// VisitEntryExpr computes the max height of a map or struct entry and associates the height with the entry id. 
+func (hv heightVisitor) VisitEntryExpr(e EntryExpr) { + hv[e.ID()] = 0 + switch e.Kind() { + case MapEntryKind: + me := e.AsMapEntry() + hv[e.ID()] = hv.maxHeight(me.Value(), me.Key()) + case StructFieldKind: + sf := e.AsStructField() + hv[e.ID()] = hv[sf.Value().ID()] + } +} + +func (hv heightVisitor) maxHeight(exprs ...Expr) int { + max := 0 + for _, e := range exprs { + h := hv[e.ID()] + if h > max { + max = h + } + } + return max +} + +func (hv heightVisitor) maxEntryHeight(entries ...EntryExpr) int { + max := 0 + for _, e := range entries { + h := hv[e.ID()] + if h > max { + max = h + } + } + return max +} diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go index a4a4a57cf..7a1c6a141 100644 --- a/common/ast/ast_test.go +++ b/common/ast/ast_test.go @@ -339,6 +339,31 @@ func TestMaxID(t *testing.T) { } } +func TestHeights(t *testing.T) { + tests := []struct { + expr string + height int + }{ + {`'a' == 'b'`, 1}, + {`'a'.size()`, 1}, + {`[1, 2].size()`, 2}, + {`size('a')`, 1}, + {`has({'a': 1}.a)`, 2}, + {`{'a': 1}`, 1}, + {`{'a': 1}['a']`, 2}, + {`[1, 2, 3].exists(i, i % 2 == 1)`, 4}, + {`google.expr.proto3.test.TestAllTypes{}`, 1}, + {`google.expr.proto3.test.TestAllTypes{repeated_int32: [1, 2]}`, 2}, + } + for _, tst := range tests { + checked := mustTypeCheck(t, tst.expr) + maxHeight := ast.Heights(checked)[checked.Expr().ID()] + if maxHeight != tst.height { + t.Errorf("ast.Heights(%q) got max height %d, wanted %d", tst.expr, maxHeight, tst.height) + } + } +} + func mockRelativeSource(t testing.TB, text string, lineOffsets []int32, baseLocation common.Location) common.Source { t.Helper() return &mockSource{ diff --git a/common/ast/navigable.go b/common/ast/navigable.go index d7a90fb7c..13e5777b5 100644 --- a/common/ast/navigable.go +++ b/common/ast/navigable.go @@ -237,8 +237,13 @@ func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) { case StructKind: s := expr.AsStruct() for _, f := range s.Fields() { - 
visitor.VisitEntryExpr(f) + if order == preOrder { + visitor.VisitEntryExpr(f) + } visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth) + if order == postOrder { + visitor.VisitEntryExpr(f) + } } } if order == postOrder { diff --git a/policy/compiler.go b/policy/compiler.go index 93505ff98..bdf495a98 100644 --- a/policy/compiler.go +++ b/policy/compiler.go @@ -198,7 +198,8 @@ func Compile(env *cel.Env, p *Policy, opts ...CompilerOption) (*cel.Ast, *cel.Is if iss.Err() != nil { return nil, iss } - composer := NewRuleComposer(env, p) + // An error cannot happen when composing without supplying options + composer, _ := NewRuleComposer(env) return composer.Compose(rule) } diff --git a/policy/compiler_test.go b/policy/compiler_test.go index cc3e80f49..b318d2d6d 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -31,8 +31,55 @@ import ( func TestCompile(t *testing.T) { for _, tst := range policyTests { - t.Run(tst.name, func(t *testing.T) { - r := newRunner(t, tst.name, tst.expr, tst.parseOpts, tst.envOpts...) 
+ tc := tst + t.Run(tc.name, func(t *testing.T) { + r := newRunner(tc.name, tc.expr, tc.parseOpts) + env, ast, iss := r.compile(t, tc.envOpts, []CompilerOption{}) + if iss.Err() != nil { + t.Fatalf("Compile(%s) failed: %v", r.name, iss.Err()) + } + r.setup(t, env, ast) + r.run(t) + }) + } +} + +func TestRuleComposerError(t *testing.T) { + env, err := cel.NewEnv() + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + _, err = NewRuleComposer(env, ExpressionUnnestHeight(-1)) + if err == nil || !strings.Contains(err.Error(), "invalid unnest") { + t.Errorf("NewRuleComposer() got %v, wanted 'invalid unnest'", err) + } +} + +func TestRuleComposerUnnest(t *testing.T) { + for _, tst := range composerUnnestTests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + r := newRunner(tc.name, tc.expr, []ParserOption{}) + env, rule, iss := r.compileRule(t) + if iss.Err() != nil { + t.Fatalf("CompileRule() failed: %v", iss.Err()) + } + rc, err := NewRuleComposer(env, tc.composerOpts...) + if err != nil { + t.Fatalf("NewRuleComposer() failed: %v", err) + } + ast, iss := rc.Compose(rule) + if iss.Err() != nil { + t.Fatalf("Compose(rule) failed: %v", iss.Err()) + } + unparsed, err := cel.AstToString(ast) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if normalize(unparsed) != normalize(tc.composed) { + t.Errorf("cel.AstToString() got %s, wanted %s", unparsed, tc.composed) + } + r.setup(t, env, ast) r.run(t) }) } @@ -40,7 +87,8 @@ func TestCompile(t *testing.T) { func TestCompileError(t *testing.T) { for _, tst := range policyErrorTests { - _, _, iss := compile(t, tst.name, []ParserOption{}, []cel.EnvOption{}, tst.compilerOpts) + policy := parsePolicy(t, tst.name, []ParserOption{}) + _, _, iss := compile(t, tst.name, policy, []cel.EnvOption{}, tst.compilerOpts) if iss.Err() == nil { t.Fatalf("compile(%s) did not error, wanted %s", tst.name, tst.err) } @@ -98,7 +146,8 @@ func TestMaxNestedExpressions_Error(t *testing.T) { wantError := `ERROR: 
testdata/required_labels/policy.yaml:15:8: error configuring compiler option: nested expression limit must be non-negative, non-zero value: -1 | name: "required_labels" | .......^` - _, _, iss := compile(t, policyName, []ParserOption{}, []cel.EnvOption{}, []CompilerOption{MaxNestedExpressions(-1)}) + policy := parsePolicy(t, policyName, []ParserOption{}) + _, _, iss := compile(t, policyName, policy, []cel.EnvOption{}, []CompilerOption{MaxNestedExpressions(-1)}) if iss.Err() == nil { t.Fatalf("compile(%s) did not error, wanted %s", policyName, wantError) } @@ -109,55 +158,40 @@ func TestMaxNestedExpressions_Error(t *testing.T) { func BenchmarkCompile(b *testing.B) { for _, tst := range policyTests { - r := newRunner(b, tst.name, tst.expr, tst.parseOpts, tst.envOpts...) + r := newRunner(tst.name, tst.expr, tst.parseOpts) + env, ast, iss := r.compile(b, tst.envOpts, []CompilerOption{}) + if iss.Err() != nil { + b.Fatalf("Compile() failed: %v", iss.Err()) + } + r.setup(b, env, ast) r.bench(b) } } -func newRunner(t testing.TB, name, expr string, parseOpts []ParserOption, opts ...cel.EnvOption) *runner { - r := &runner{ +func newRunner(name, expr string, parseOpts []ParserOption, opts ...cel.EnvOption) *runner { + return &runner{ name: name, - envOpts: opts, parseOpts: parseOpts, expr: expr} - r.setup(t) - return r } type runner struct { - name string - envOpts []cel.EnvOption - parseOpts []ParserOption - compilerOpts []CompilerOption - env *cel.Env - expr string - prg cel.Program + name string + parseOpts []ParserOption + env *cel.Env + expr string + prg cel.Program } -func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast { - t.Helper() - out, iss := env.Compile(expr) - if iss.Err() != nil { - t.Fatalf("env.Compile(%s) failed: %v", expr, iss.Err()) - } - return out +func (r *runner) compile(t testing.TB, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) { + policy := parsePolicy(t, r.name, r.parseOpts) + return 
compile(t, r.name, policy, envOpts, compilerOpts) } -func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) { +func (r *runner) compileRule(t testing.TB) (*cel.Env, *CompiledRule, *cel.Issues) { t.Helper() - config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name)) - srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", name)) - parser, err := NewParser(parseOpts...) - if err != nil { - t.Fatalf("NewParser() failed: %v", err) - } - policy, iss := parser.Parse(srcFile) - if iss.Err() != nil { - t.Fatalf("Parse() failed: %v", iss.Err()) - } - if policy.name.Value != name { - t.Errorf("policy name is %v, wanted %q", policy.name, name) - } + config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", r.name)) + policy := parsePolicy(t, r.name, r.parseOpts) env, err := cel.NewCustomEnv( cel.OptionalTypes(), cel.EnableMacroCallTracking(), @@ -166,26 +200,17 @@ func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel. if err != nil { t.Fatalf("cel.NewEnv() failed: %v", err) } - // Configure any custom environment options. - env, err = env.Extend(envOpts...) - if err != nil { - t.Fatalf("env.Extend() with env options %v, failed: %v", config, err) - } // Configure declarations env, err = env.Extend(FromConfig(config)) if err != nil { t.Fatalf("env.Extend() with config options %v, failed: %v", config, err) } - ast, iss := Compile(env, policy, compilerOpts...) 
- return env, ast, iss + rule, iss := CompileRule(env, policy) + return env, rule, iss } -func (r *runner) setup(t testing.TB) { +func (r *runner) setup(t testing.TB, env *cel.Env, ast *cel.Ast) { t.Helper() - env, ast, iss := compile(t, r.name, r.parseOpts, r.envOpts, r.compilerOpts) - if iss.Err() != nil { - t.Fatalf("Compile(%s) failed: %v", r.name, iss.Err()) - } pExpr, err := cel.AstToString(ast) if err != nil { t.Fatalf("cel.AstToString() failed: %v", err) @@ -323,6 +348,56 @@ func (r *runner) eval(t testing.TB, expr string) ref.Val { return out } +func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast { + t.Helper() + out, iss := env.Compile(expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%s) failed: %v", expr, iss.Err()) + } + return out +} + +func parsePolicy(t testing.TB, name string, parseOpts []ParserOption) *Policy { + t.Helper() + srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", name)) + parser, err := NewParser(parseOpts...) + if err != nil { + t.Fatalf("NewParser() failed: %v", err) + } + policy, iss := parser.Parse(srcFile) + if iss.Err() != nil { + t.Fatalf("Parse() failed: %v", iss.Err()) + } + if policy.name.Value != name { + t.Errorf("policy name is %v, wanted %q", policy.name, name) + } + return policy +} + +func compile(t testing.TB, name string, policy *Policy, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) { + config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name)) + env, err := cel.NewCustomEnv( + cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + cel.ExtendedValidations(), + ext.Bindings()) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + // Configure any custom environment options. + env, err = env.Extend(envOpts...) 
+ if err != nil { + t.Fatalf("env.Extend() with env options %v, failed: %v", config, err) + } + // Configure declarations + env, err = env.Extend(FromConfig(config)) + if err != nil { + t.Fatalf("env.Extend() with config options %v, failed: %v", config, err) + } + ast, iss := Compile(env, policy, compilerOpts...) + return env, ast, iss +} + func normalize(s string) string { return strings.ReplaceAll( strings.ReplaceAll( diff --git a/policy/composer.go b/policy/composer.go index be326ded4..0b9be2a5c 100644 --- a/policy/composer.go +++ b/policy/composer.go @@ -15,7 +15,9 @@ package policy import ( + "cmp" "fmt" + "slices" "strings" "github.com/google/cel-go/cel" @@ -24,25 +26,58 @@ import ( "github.com/google/cel-go/common/types" ) +// ComposerOption is a functional option used to configure a RuleComposer +type ComposerOption func(*RuleComposer) (*RuleComposer, error) + +// ExpressionUnnestHeight determines the height at which nested expressions are split into local +// variables within the cel.@block declaration. +func ExpressionUnnestHeight(height int) ComposerOption { + return func(c *RuleComposer) (*RuleComposer, error) { + if height <= 0 { + return nil, fmt.Errorf("invalid unnest height: value must be positive: %d", height) + } + c.exprUnnestHeight = height + return c, nil + } +} + // NewRuleComposer creates a rule composer which stitches together rules within a policy into // a single CEL expression. -func NewRuleComposer(env *cel.Env, p *Policy) *RuleComposer { - return &RuleComposer{ +func NewRuleComposer(env *cel.Env, opts ...ComposerOption) (*RuleComposer, error) { + composer := &RuleComposer{ env: env, - p: p, + // set the default nesting height to something reasonable. + exprUnnestHeight: 25, + } + var err error + for _, opt := range opts { + composer, err = opt(composer) + if err != nil { + return nil, err + } } + return composer, nil } // RuleComposer optimizes a set of expressions into a single expression. 
type RuleComposer struct { env *cel.Env - p *Policy + + // exprUnnestHeight determines the height at which nested matches are split into + // index variables within a cel.@block index declaration when composing matches under + // the first-match semantic. + exprUnnestHeight int } // Compose stitches together a set of expressions within a CompiledRule into a single CEL ast. func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { ruleRoot, _ := c.env.Compile("true") - opt := cel.NewStaticOptimizer(&ruleComposerImpl{rule: r, varIndices: []varIndex{}}) + opt := cel.NewStaticOptimizer( + &ruleComposerImpl{ + rule: r, + varIndices: []varIndex{}, + exprUnnestHeight: c.exprUnnestHeight, + }) return opt.Optimize(c.env, ruleRoot) } @@ -51,7 +86,7 @@ type varIndex struct { indexVar string localVar string expr ast.Expr - cv *CompiledVariable + celType *types.Type } type ruleComposerImpl struct { @@ -59,7 +94,7 @@ type ruleComposerImpl struct { nextVarIndex int varIndices []varIndex - maxNestedExpressionLimit int + exprUnnestHeight int } // Optimize implements an AST optimizer for CEL which composes an expression graph into a single @@ -68,17 +103,23 @@ func (opt *ruleComposerImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *as // The input to optimize is a dummy expression which is completely replaced according // to the configuration of the rule composition graph. ruleExpr := opt.optimizeRule(ctx, opt.rule) + // If the rule is deeply nested, it may need to be unnested. This process may generate + // additional variables that are included in the `sortedVariables` list. + ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr) + + // Collect all variables associated with the rule expression. allVars := opt.sortedVariables() // If there were no variables, return the expression. if len(allVars) == 0 { return ctx.NewAST(ruleExpr) } - // Otherwise populate the block. 
+ // Otherwise populate the cel.@block with the variable declarations and wrap the expression + // in the block. varExprs := make([]ast.Expr, len(allVars)) for i, vi := range allVars { varExprs[i] = vi.expr - err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.cv.Declaration().Type())) + err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType)) if err != nil { ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error()) } @@ -156,6 +197,57 @@ func (opt *ruleComposerImpl) rewriteVariableName(ctx *cel.OptimizerContext) ast. }) } +func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.Expr) ast.Expr { + // Split the expr into local variables based on expression height + ruleAST := ctx.NewAST(ruleExpr) + ruleNav := ast.NavigateAST(ruleAST) + // Unnest expressions are ordered from leaf to root via the ast.MatchDescendants call. + heights := ast.Heights(ruleAST) + unnestMap := map[int64]bool{} + unnestExprs := []ast.NavigableExpr{} + ast.MatchDescendants(ruleNav, func(e ast.NavigableExpr) bool { + // If the expression is a comprehension, then all unnest candidates captured previously that relate + // to the comprehension body should be removed from the list of candidate branches for unnesting. + if e.Kind() == ast.ComprehensionKind { + // This only removes branches from the map, but not from the list of branches. + removeIneligibleSubExprs(e, unnestMap) + return false + } + // Otherwise, if the expression is not a call, don't include it. + if e.Kind() != ast.CallKind { + return false + } + height := heights[e.ID()] + if height < opt.exprUnnestHeight { + return false + } + unnestMap[e.ID()] = true + unnestExprs = append(unnestExprs, e) + return true + }) + + slices.SortStableFunc(unnestExprs, func(a, b ast.NavigableExpr) int { + heightA := heights[a.ID()] + heightB := heights[b.ID()] + return cmp.Compare(heightA, heightB) + }) + + // Prune the expression set to unnest down to only those not included in comprehensions. 
+ for idx := 0; idx < len(unnestExprs)-1; idx++ { + e := unnestExprs[idx] + if present, found := unnestMap[e.ID()]; !found || !present { + continue + } + height := heights[e.ID()] + if height < opt.exprUnnestHeight { + continue + } + reduceHeight(heights, e, opt.exprUnnestHeight) + opt.registerBranchVariable(ctx, e) + } + return ruleExpr +} + // registerVariable creates an entry for a variable name within the cel.@block used to enumerate // variables within composed policy expression. func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) { @@ -168,7 +260,23 @@ func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *Comp indexVar: indexVar, localVar: varName, expr: varExpr, - cv: v} + celType: v.Declaration().Type()} + opt.varIndices = append(opt.varIndices, vi) + opt.nextVarIndex++ +} + +// registerBranchVariable creates an entry for a variable name within the cel.@block used to unnest +// a deeply nested logical branch or logical operator. +func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) { + indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex) + varExprCopy := ctx.CopyASTAndMetadata(ctx.NewAST(varExpr)) + vi := varIndex{ + index: opt.nextVarIndex, + indexVar: indexVar, + expr: varExprCopy, + celType: varExpr.Type(), + } + ctx.UpdateExpr(varExpr, ctx.NewIdent(vi.indexVar)) opt.varIndices = append(opt.varIndices, vi) opt.nextVarIndex++ } @@ -270,6 +378,7 @@ func (s nonOptionalCompositionStep) combine(step compositionStep) compositionSte ) } // The `step` is pruned away by a unconditional non-optional step `s`. + // Likely a candidate for dead-code warnings. 
return s } return newNonOptionalCompositionStep(ctx, @@ -362,3 +471,42 @@ func isOptionalNone(e ast.Expr) bool { e.AsCall().FunctionName() == "optional.none" && len(e.AsCall().Args()) == 0 } + +func removeIneligibleSubExprs(e ast.NavigableExpr, unnestMap map[int64]bool) { + for _, id := range comprehensionSubExprIDs(e) { + if _, found := unnestMap[id]; found { + delete(unnestMap, id) + } + } +} + +func comprehensionSubExprIDs(e ast.NavigableExpr) []int64 { + compre := e.AsComprehension() + // Almost the same as e.Children(), but skips the iteration range + return enumerateExprIDs( + compre.AccuInit().(ast.NavigableExpr), + compre.LoopCondition().(ast.NavigableExpr), + compre.LoopStep().(ast.NavigableExpr), + compre.Result().(ast.NavigableExpr), + ) +} + +func enumerateExprIDs(exprs ...ast.NavigableExpr) []int64 { + ids := make([]int64, 0, len(exprs)) + for _, e := range exprs { + ids = append(ids, e.ID()) + ids = append(ids, enumerateExprIDs(e.Children()...)...) + } + return ids +} + +func reduceHeight(heights map[int64]int, e ast.NavigableExpr, amount int) { + height := heights[e.ID()] + if height < amount { + return + } + heights[e.ID()] = height - amount + if parent, hasParent := e.Parent(); hasParent { + reduceHeight(heights, parent, amount) + } +} diff --git a/policy/helper_test.go b/policy/helper_test.go index 3934301fe..fbe3183a5 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -35,7 +35,6 @@ var ( envOpts []cel.EnvOption parseOpts []ParserOption expr string - expr2 string }{ { name: "k8s", @@ -113,6 +112,19 @@ var ( .or(((x > 3) ? optional.of(true) : optional.none()) .or((x > 1) ? optional.of(false) : optional.none()))`, }, + { + name: "unnest", + expr: ` + cel.@block([values.filter(x, x > 2)], + ((@index0.size() == 0) ? false : @index0.all(x, x % 2 == 0)) + ? optional.of("some divisible by 2") + : (values.map(x, x * 3).exists(x, x % 4 == 0) + ? 
optional.of("at least one divisible by 4") + : (values.map(x, x * x * x).exists(x, x % 6 == 0) + ? optional.of("at least one power of 6") + : optional.none()))) + `, + }, { name: "context_pb", expr: ` @@ -145,7 +157,7 @@ var ( cel.@block([ spec.labels, @index0.filter(l, !(l in resource.labels)), - resource.labels.transformList(l, value, l in @index0 && value != @index0[l], l)], + resource.labels.transformList(l, value, l in @index0 && value != @index0[l], l)], (@index1.size() > 0) ? optional.of("missing one or more required labels: %s".format([@index1])) : ((@index2.size() > 0) @@ -199,6 +211,140 @@ var ( }, } + composerUnnestTests = []struct { + name string + expr string + composed string + composerOpts []ComposerOption + }{ + { + name: "unnest", + composerOpts: []ComposerOption{ExpressionUnnestHeight(2)}, + composed: ` + cel.@block([ + values.filter(x, x > 2), + @index0.size() == 0, + @index1 ? false : @index0.all(x, x % 2 == 0), + values.map(x, x * x * x).exists(x, x % 6 == 0) + ? optional.of("at least one power of 6") + : optional.none(), + values.map(x, x * 3).exists(x, x % 4 == 0) + ? optional.of("at least one divisible by 4") + : @index3], + @index2 ? optional.of("some divisible by 2") : @index4) + `, + }, + { + name: "required_labels", + composerOpts: []ComposerOption{ExpressionUnnestHeight(2)}, + composed: ` + cel.@block([ + spec.labels, + @index0.filter(l, !(l in resource.labels)), + resource.labels.transformList(l, value, l in @index0 && value != @index0[l], l), + @index1.size() > 0, + "missing one or more required labels: %s".format([@index1]), + @index2.size() > 0, + "invalid values provided on one or more labels: %s".format([@index2])], + @index3 ? optional.of(@index4) : (@index5 ? 
optional.of(@index6) : optional.none())) + `, + }, + { + name: "required_labels", + composerOpts: []ComposerOption{ExpressionUnnestHeight(4)}, + composed: ` + cel.@block([ + spec.labels, + @index0.filter(l, !(l in resource.labels)), + resource.labels.transformList(l, value, l in @index0 && value != @index0[l], l), + (@index2.size() > 0) + ? optional.of("invalid values provided on one or more labels: %s".format([@index2])) + : optional.none() + ], + (@index1.size() > 0) + ? optional.of("missing one or more required labels: %s".format([@index1])) + : @index3)`, + }, + { + name: "nested_rule2", + composerOpts: []ComposerOption{ExpressionUnnestHeight(4)}, + composed: ` + cel.@block([ + ["us", "uk", "es"], + {"us": false, "ru": false, "ir": false}, + resource.origin in @index1 && !(resource.origin in @index0), + !(resource.origin in @index0) ? {"banned": "unconfigured_region"} : {}], + resource.?user.orValue("").startsWith("bad") + ? (@index2 ? {"banned": "restricted_region"} : {"banned": "bad_actor"}) + : @index3)`, + }, + { + name: "nested_rule2", + composerOpts: []ComposerOption{ExpressionUnnestHeight(5)}, + composed: ` + cel.@block([ + ["us", "uk", "es"], + {"us": false, "ru": false, "ir": false}, + (resource.origin in @index1 && !(resource.origin in @index0)) + ? {"banned": "restricted_region"} + : {"banned": "bad_actor"}], + resource.?user.orValue("").startsWith("bad") + ? @index2 + : (!(resource.origin in @index0) + ? {"banned": "unconfigured_region"} + : {}))`, + }, + { + name: "limits", + composerOpts: []ComposerOption{ExpressionUnnestHeight(3)}, + composed: ` + cel.@block([ + "hello", + "goodbye", + "me", + "%s, %s", + @index3.format([@index1, @index2]), + (now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none(), + optional.of(@index3.format([@index0, @index2]))], + (now.getHours() >= 20) + ? ((now.getHours() < 21) ? optional.of(@index4 + "!") : + ((now.getHours() < 22) ? 
optional.of(@index4 + "!!") : @index5)) + : @index6)`, + }, + { + name: "limits", + composerOpts: []ComposerOption{ExpressionUnnestHeight(4)}, + composed: ` + cel.@block([ + "hello", + "goodbye", + "me", + "%s, %s", + @index3.format([@index1, @index2]), + (now.getHours() < 22) ? optional.of(@index4 + "!!") : + ((now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none())], + (now.getHours() >= 20) + ? ((now.getHours() < 21) ? optional.of(@index4 + "!") : @index5) + : optional.of(@index3.format([@index0, @index2]))) + `, + }, + { + name: "limits", + composerOpts: []ComposerOption{ExpressionUnnestHeight(5)}, + composed: ` + cel.@block([ + "hello", + "goodbye", + "me", + "%s, %s", + @index3.format([@index1, @index2]), + (now.getHours() < 21) ? optional.of(@index4 + "!") : + ((now.getHours() < 22) ? optional.of(@index4 + "!!") : + ((now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none()))], + (now.getHours() >= 20) ? @index5 : optional.of(@index3.format([@index0, @index2])))`, + }, + } + policyErrorTests = []struct { name string err string diff --git a/policy/testdata/unnest/config.yaml b/policy/testdata/unnest/config.yaml new file mode 100644 index 000000000..1891ed689 --- /dev/null +++ b/policy/testdata/unnest/config.yaml @@ -0,0 +1,20 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: "unnest" +variables: + - name: values + type_name: list + params: + - type_name: int diff --git a/policy/testdata/unnest/policy.yaml b/policy/testdata/unnest/policy.yaml new file mode 100644 index 000000000..af63683cf --- /dev/null +++ b/policy/testdata/unnest/policy.yaml @@ -0,0 +1,32 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: unnest +rule: + variables: + - name: even_greater + expression: > + values.filter(x, x > 2) + match: + - condition: > + variables.even_greater.size() == 0 ? false : + variables.even_greater.all(x, x % 2 == 0) + output: > + "some divisible by 2" + - condition: "values.map(x, x * 3).exists(x, x % 4 == 0)" + output: > + "at least one divisible by 4" + - condition: "values.map(x, x * x * x).exists(x, x % 6 == 0)" + output: > + "at least one power of 6" diff --git a/policy/testdata/unnest/tests.yaml b/policy/testdata/unnest/tests.yaml new file mode 100644 index 000000000..9bed7b352 --- /dev/null +++ b/policy/testdata/unnest/tests.yaml @@ -0,0 +1,50 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +description: "Unnest tests unnesting of comprehension sequences" +section: + - name: "divisible by 2" + tests: + - name: "true" + input: + values: + expr: "[4, 6]" + output: > + "some divisible by 2" + - name: "false" + input: + values: + expr: "[1, 3, 5]" + output: "optional.none()" + - name: "empty-set" + input: + values: + expr: "[1, 2]" + output: "optional.none()" + - name: "divisible by 4" + tests: + - name: "true" + input: + values: + expr: "[4, 7]" + output: > + "at least one divisible by 4" + - name: "power of 6" + tests: + - name: "true" + input: + values: + expr: "[6, 7]" + output: > + "at least one power of 6" From 3b3a43818bba0cb1011948d29822db0dd7871f33 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 26 Feb 2025 14:00:31 -0800 Subject: [PATCH 16/46] Separate unnest optimization from composer to capture type info (#1138) * Separate unnest optimization from composer to capture type info * Simplify the variable tracking during unnest --- policy/compiler_test.go | 3 + policy/composer.go | 150 +++++++++++++++++++++++++++------------- policy/helper_test.go | 9 +++ 3 files changed, 114 insertions(+), 48 deletions(-) diff --git a/policy/compiler_test.go b/policy/compiler_test.go index b318d2d6d..545b01f8b 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -79,6 +79,9 @@ func TestRuleComposerUnnest(t *testing.T) { if normalize(unparsed) != normalize(tc.composed) { t.Errorf("cel.AstToString() got %s, wanted %s", unparsed, tc.composed) } + if !ast.OutputType().IsEquivalentType(tc.outputType) { + t.Errorf("ast.OutputType() 
got %v, wanted %v", ast.OutputType(), tc.outputType) + } r.setup(t, env, ast) r.run(t) }) diff --git a/policy/composer.go b/policy/composer.go index 0b9be2a5c..762472487 100644 --- a/policy/composer.go +++ b/policy/composer.go @@ -72,13 +72,21 @@ type RuleComposer struct { // Compose stitches together a set of expressions within a CompiledRule into a single CEL ast. func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { ruleRoot, _ := c.env.Compile("true") - opt := cel.NewStaticOptimizer( - &ruleComposerImpl{ - rule: r, - varIndices: []varIndex{}, - exprUnnestHeight: c.exprUnnestHeight, - }) - return opt.Optimize(c.env, ruleRoot) + composer := &ruleComposerImpl{ + rule: r, + varIndices: []varIndex{}, + } + opt := cel.NewStaticOptimizer(composer) + ast, iss := opt.Optimize(c.env, ruleRoot) + if iss.Err() != nil { + return nil, iss + } + unnester := &ruleUnnesterImpl{ + varIndices: []varIndex{}, + exprUnnestHeight: c.exprUnnestHeight, + } + opt = cel.NewStaticOptimizer(unnester) + return opt.Optimize(c.env, ast) } type varIndex struct { @@ -93,8 +101,6 @@ type ruleComposerImpl struct { rule *CompiledRule nextVarIndex int varIndices []varIndex - - exprUnnestHeight int } // Optimize implements an AST optimizer for CEL which composes an expression graph into a single @@ -103,21 +109,16 @@ func (opt *ruleComposerImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *as // The input to optimize is a dummy expression which is completely replaced according // to the configuration of the rule composition graph. ruleExpr := opt.optimizeRule(ctx, opt.rule) - // If the rule is deeply nested, it may need to be unnested. This process may generate - // additional variables that are included in the `sortedVariables` list. - ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr) - // Collect all variables associated with the rule expression. - allVars := opt.sortedVariables() // If there were no variables, return the expression. 
- if len(allVars) == 0 { + if len(opt.varIndices) == 0 { return ctx.NewAST(ruleExpr) } // Otherwise populate the cel.@block with the variable declarations and wrap the expression // in the block. - varExprs := make([]ast.Expr, len(allVars)) - for i, vi := range allVars { + varExprs := make([]ast.Expr, len(opt.varIndices)) + for i, vi := range opt.varIndices { varExprs[i] = vi.expr err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType)) if err != nil { @@ -197,15 +198,90 @@ func (opt *ruleComposerImpl) rewriteVariableName(ctx *cel.OptimizerContext) ast. }) } -func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.Expr) ast.Expr { - // Split the expr into local variables based on expression height - ruleAST := ctx.NewAST(ruleExpr) - ruleNav := ast.NavigateAST(ruleAST) +// registerVariable creates an entry for a variable name within the cel.@block used to enumerate +// variables within composed policy expression. +func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) { + varName := fmt.Sprintf("variables.%s", v.Name()) + indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex) + varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep()) + ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx)) + vi := varIndex{ + index: opt.nextVarIndex, + indexVar: indexVar, + localVar: varName, + expr: varExpr, + celType: v.Declaration().Type()} + opt.varIndices = append(opt.varIndices, vi) + opt.nextVarIndex++ +} + +type ruleUnnesterImpl struct { + nextVarIndex int + varIndices []varIndex + exprUnnestHeight int +} + +func (opt *ruleUnnesterImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST { + // Since the optimizer is based on the original environment provided to the composer, + // a second pass on the `cel.@block` will require a rebuilding of the cel environment + ruleExpr := ast.NavigateAST(a) + var varExprs []ast.Expr + var varDecls []cel.EnvOption + if ruleExpr.Kind() == ast.CallKind && 
ruleExpr.AsCall().FunctionName() == "cel.@block" { + // Extract the expr from the cel.@block, args[1], as a navigable expr value. + // Also extract the variable declarations and all associated types from the cel.@block as + // varIndex values, but without doing any rewrites as the types are all correct already. + block := ruleExpr.AsCall() + ruleExpr = block.Args()[1].(ast.NavigableExpr) + + // Collect the list of variables associated with the block + blockList := block.Args()[0].(ast.NavigableExpr) + vars := blockList.AsList() + varExprs = make([]ast.Expr, vars.Size()) + varDecls = make([]cel.EnvOption, vars.Size()) + copy(varExprs, vars.Elements()) + for i, v := range varExprs { + // Track the variable he varDecls set. + indexVar := fmt.Sprintf("@index%d", i) + celType := a.GetType(v.ID()) + varDecls[i] = cel.Variable(indexVar, celType) + opt.nextVarIndex++ + } + } + if len(varDecls) != 0 { + err := ctx.ExtendEnv(varDecls...) + if err != nil { + ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error()) + } + } + + // Attempt to unnest the rule. + ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr) + // If there were no variables, return the expression. + if len(opt.varIndices) == 0 { + return a + } + + // Otherwise populate the cel.@block with the variable declarations and wrap the expression + // in the block. + for i := 0; i < len(opt.varIndices); i++ { + vi := opt.varIndices[i] + varExprs = append(varExprs, vi.expr) + err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType)) + if err != nil { + ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error()) + } + } + blockExpr := ctx.NewCall("cel.@block", ctx.NewList(varExprs, []int32{}), ruleExpr) + return ctx.NewAST(blockExpr) +} + +func (opt *ruleUnnesterImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.NavigableExpr) ast.NavigableExpr { // Unnest expressions are ordered from leaf to root via the ast.MatchDescendants call. 
- heights := ast.Heights(ruleAST) + heights := ast.Heights(ast.NewAST(ruleExpr, nil)) unnestMap := map[int64]bool{} unnestExprs := []ast.NavigableExpr{} - ast.MatchDescendants(ruleNav, func(e ast.NavigableExpr) bool { + ast.MatchDescendants(ruleExpr, func(e ast.NavigableExpr) bool { // If the expression is a comprehension, then all unnest candidates captured previously that relate // to the comprehension body should be removed from the list of candidate branches for unnesting. if e.Kind() == ast.ComprehensionKind { @@ -243,31 +319,14 @@ func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr continue } reduceHeight(heights, e, opt.exprUnnestHeight) - opt.registerBranchVariable(ctx, e) + opt.registerUnnestVariable(ctx, e) } return ruleExpr } -// registerVariable creates an entry for a variable name within the cel.@block used to enumerate -// variables within composed policy expression. -func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) { - varName := fmt.Sprintf("variables.%s", v.Name()) - indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex) - varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep()) - ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx)) - vi := varIndex{ - index: opt.nextVarIndex, - indexVar: indexVar, - localVar: varName, - expr: varExpr, - celType: v.Declaration().Type()} - opt.varIndices = append(opt.varIndices, vi) - opt.nextVarIndex++ -} - -// registerBranchVariable creates an entry for a variable name within the cel.@block used to unnest +// registerUnnestVariable creates an entry for a variable name within the cel.@block used to unnest // a deeply nested logical branch or logical operator. 
-func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) { +func (opt *ruleUnnesterImpl) registerUnnestVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) { indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex) varExprCopy := ctx.CopyASTAndMetadata(ctx.NewAST(varExpr)) vi := varIndex{ @@ -281,11 +340,6 @@ func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, v opt.nextVarIndex++ } -// sortedVariables returns the variables ordered by their declaration index. -func (opt *ruleComposerImpl) sortedVariables() []varIndex { - return opt.varIndices -} - // compositionStep interface represents an intermediate stage of rule and match expression composition // // The CompiledRule and CompiledMatch types are meant to represent standalone tuples of condition diff --git a/policy/helper_test.go b/policy/helper_test.go index fbe3183a5..8e117331c 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -216,6 +216,7 @@ var ( expr string composed string composerOpts []ComposerOption + outputType *cel.Type }{ { name: "unnest", @@ -233,6 +234,7 @@ var ( : @index3], @index2 ? optional.of("some divisible by 2") : @index4) `, + outputType: cel.OptionalType(cel.StringType), }, { name: "required_labels", @@ -248,6 +250,7 @@ var ( "invalid values provided on one or more labels: %s".format([@index2])], @index3 ? optional.of(@index4) : (@index5 ? optional.of(@index6) : optional.none())) `, + outputType: cel.OptionalType(cel.StringType), }, { name: "required_labels", @@ -264,6 +267,7 @@ var ( (@index1.size() > 0) ? optional.of("missing one or more required labels: %s".format([@index1])) : @index3)`, + outputType: cel.OptionalType(cel.StringType), }, { name: "nested_rule2", @@ -277,6 +281,7 @@ var ( resource.?user.orValue("").startsWith("bad") ? (@index2 ? 
{"banned": "restricted_region"} : {"banned": "bad_actor"}) : @index3)`, + outputType: cel.MapType(cel.StringType, cel.StringType), }, { name: "nested_rule2", @@ -293,6 +298,7 @@ var ( : (!(resource.origin in @index0) ? {"banned": "unconfigured_region"} : {}))`, + outputType: cel.MapType(cel.StringType, cel.StringType), }, { name: "limits", @@ -310,6 +316,7 @@ var ( ? ((now.getHours() < 21) ? optional.of(@index4 + "!") : ((now.getHours() < 22) ? optional.of(@index4 + "!!") : @index5)) : @index6)`, + outputType: cel.OptionalType(cel.StringType), }, { name: "limits", @@ -327,6 +334,7 @@ var ( ? ((now.getHours() < 21) ? optional.of(@index4 + "!") : @index5) : optional.of(@index3.format([@index0, @index2]))) `, + outputType: cel.OptionalType(cel.StringType), }, { name: "limits", @@ -342,6 +350,7 @@ var ( ((now.getHours() < 22) ? optional.of(@index4 + "!!") : ((now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none()))], (now.getHours() >= 20) ? @index5 : optional.of(@index3.format([@index0, @index2])))`, + outputType: cel.OptionalType(cel.StringType), }, } From f5ea07b389a114904c1dd11d91f0ad387cb23fb3 Mon Sep 17 00:00:00 2001 From: "zev.ac" Date: Tue, 11 Mar 2025 23:18:52 +0530 Subject: [PATCH 17/46] Expose extension option factory as a public method (#1141) * added tests for extension option factory --- common/env/env.go | 11 +++-- ext/BUILD.bazel | 6 ++- ext/extension_option_factory.go | 72 ++++++++++++++++++++++++++++ ext/extension_option_factory_test.go | 67 ++++++++++++++++++++++++++ policy/config.go | 54 +-------------------- 5 files changed, 151 insertions(+), 59 deletions(-) create mode 100644 ext/extension_option_factory.go create mode 100644 ext/extension_option_factory_test.go diff --git a/common/env/env.go b/common/env/env.go index aa7b066c5..07294c696 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -115,7 +115,7 @@ func (c *Config) AddVariableDecls(vars ...*decls.VariableDecl) *Config { if v == nil { continue } - convVars[i] = 
NewVariable(v.Name(), serializeTypeDesc(v.Type())) + convVars[i] = NewVariable(v.Name(), SerializeTypeDesc(v.Type())) } return c.AddVariables(convVars...) } @@ -146,9 +146,9 @@ func (c *Config) AddFunctionDecls(funcs ...*decls.FunctionDecl) *Config { overloadID := o.ID() args := make([]*TypeDesc, 0, len(o.ArgTypes())) for _, a := range o.ArgTypes() { - args = append(args, serializeTypeDesc(a)) + args = append(args, SerializeTypeDesc(a)) } - ret := serializeTypeDesc(o.ResultType()) + ret := SerializeTypeDesc(o.ResultType()) if o.IsMemberFunction() { overloads = append(overloads, NewMemberOverload(overloadID, args[0], args[1:], ret)) } else { @@ -836,7 +836,8 @@ func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) { } } -func serializeTypeDesc(t *types.Type) *TypeDesc { +// SerializeTypeDesc converts *types.Type to a serialized format TypeDesc +func SerializeTypeDesc(t *types.Type) *TypeDesc { typeName := t.TypeName() if t.Kind() == types.TypeParamKind { return NewTypeParam(typeName) @@ -848,7 +849,7 @@ func serializeTypeDesc(t *types.Type) *TypeDesc { } var params []*TypeDesc for _, p := range t.Parameters() { - params = append(params, serializeTypeDesc(p)) + params = append(params, SerializeTypeDesc(p)) } return NewTypeDesc(typeName, params...) 
} diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index b764fa1f5..62863c17a 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "bindings.go", "comprehensions.go", "encoders.go", + "extension_option_factory.go", "formatting.go", "guards.go", "lists.go", @@ -26,6 +27,7 @@ go_library( "//checker:go_default_library", "//common/ast:go_default_library", "//common/decls:go_default_library", + "//common/env:go_default_library", "//common/overloads:go_default_library", "//common/operators:go_default_library", "//common/types:go_default_library", @@ -48,7 +50,8 @@ go_test( srcs = [ "bindings_test.go", "comprehensions_test.go", - "encoders_test.go", + "encoders_test.go", + "extension_option_factory_test.go", "lists_test.go", "math_test.go", "native_test.go", @@ -62,6 +65,7 @@ go_test( deps = [ "//cel:go_default_library", "//checker:go_default_library", + "//common/env:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "//common/types/traits:go_default_library", diff --git a/ext/extension_option_factory.go b/ext/extension_option_factory.go new file mode 100644 index 000000000..4906227a5 --- /dev/null +++ b/ext/extension_option_factory.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ext + +import ( + "fmt" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/env" +) + +// ExtensionOptionFactory converts an ExtensionConfig value to a CEL environment option. +func ExtensionOptionFactory(configElement any) (cel.EnvOption, bool) { + ext, isExtension := configElement.(*env.Extension) + if !isExtension { + return nil, false + } + fac, found := extFactories[ext.Name] + if !found { + return nil, false + } + // If the version is 'latest', set the version value to the max uint. + ver, err := ext.VersionNumber() + if err != nil { + return func(*cel.Env) (*cel.Env, error) { + return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version) + }, true + } + return fac(ver), true +} + +// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension. +type extensionFactory func(uint32) cel.EnvOption + +var extFactories = map[string]extensionFactory{ + "bindings": func(version uint32) cel.EnvOption { + return Bindings(BindingsVersion(version)) + }, + "encoders": func(version uint32) cel.EnvOption { + return Encoders(EncodersVersion(version)) + }, + "lists": func(version uint32) cel.EnvOption { + return Lists(ListsVersion(version)) + }, + "math": func(version uint32) cel.EnvOption { + return Math(MathVersion(version)) + }, + "protos": func(version uint32) cel.EnvOption { + return Protos(ProtosVersion(version)) + }, + "sets": func(version uint32) cel.EnvOption { + return Sets(SetsVersion(version)) + }, + "strings": func(version uint32) cel.EnvOption { + return Strings(StringsVersion(version)) + }, + "two-var-comprehensions": func(version uint32) cel.EnvOption { + return TwoVarComprehensions(TwoVarComprehensionsVersion(version)) + }, +} diff --git a/ext/extension_option_factory_test.go b/ext/extension_option_factory_test.go new file mode 100644 index 000000000..f721bb6bf --- /dev/null +++ b/ext/extension_option_factory_test.go @@ -0,0 +1,67 @@ +// Copyright 2025 Google LLC 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/env" +) + +func TestExtensionOptionFactoryInvalidExtension(t *testing.T) { + invalidExtension := "invalid extension" + _, validExtension := ExtensionOptionFactory(invalidExtension) + if validExtension { + t.Fatalf("ExtensionOptionFactory(%s) returned valid extension for invalid input", invalidExtension) + } +} + +func TestExtensionOptionFactoryInvalidExtensionName(t *testing.T) { + e := &env.Extension{Name: "invalid extension name"} + _, validExtension := ExtensionOptionFactory(e) + if validExtension { + t.Fatalf("ExtensionOptionFactory(%s) returned valid extension for invalid extension name", e.Name) + } +} + +func TestExtensionOptionFactoryInvalidExtensionVersion(t *testing.T) { + e := &env.Extension{Name: "bindings", Version: "invalid version"} + opt, validExtension := ExtensionOptionFactory(e) + if !validExtension { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) + } + _, err := cel.NewCustomEnv(opt) + if err == nil || err.Error() != fmt.Sprintf("invalid extension version: %s - %s", e.Name, e.Version) { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension version", e.Name) + } +} + +func TestExtensionOptionFactoryValidBindingsExtension(t *testing.T) { + e := &env.Extension{Name: "bindings", Version: "latest"} + opt, validExtension := 
ExtensionOptionFactory(e) + if !validExtension { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) + } + en, err := cel.NewCustomEnv(opt) + if err != nil { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) + } + cfg, err := en.ToConfig("test config") + if len(cfg.Extensions) != 1 || cfg.Extensions[0].Name != "cel.lib.ext.cel.bindings" || cfg.Extensions[0].Version != "latest" { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) + } +} diff --git a/policy/config.go b/policy/config.go index 12fbe44a7..02243922b 100644 --- a/policy/config.go +++ b/policy/config.go @@ -15,8 +15,6 @@ package policy import ( - "fmt" - "github.com/google/cel-go/cel" "github.com/google/cel-go/common/env" "github.com/google/cel-go/ext" @@ -28,55 +26,5 @@ import ( // a set of configuration ConfigOptionFactory values to handle extensions and other config features // which may be defined outside of the `cel` package. func FromConfig(config *env.Config) cel.EnvOption { - return cel.FromConfig(config, extensionOptionFactory) -} - -// extensionOptionFactory converts an ExtensionConfig value to a CEL environment option. -func extensionOptionFactory(configElement any) (cel.EnvOption, bool) { - ext, isExtension := configElement.(*env.Extension) - if !isExtension { - return nil, false - } - fac, found := extFactories[ext.Name] - if !found { - return nil, false - } - // If the version is 'latest', set the version value to the max uint. - ver, err := ext.VersionNumber() - if err != nil { - return func(*cel.Env) (*cel.Env, error) { - return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version) - }, true - } - return fac(ver), true -} - -// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension. 
-type extensionFactory func(uint32) cel.EnvOption - -var extFactories = map[string]extensionFactory{ - "bindings": func(version uint32) cel.EnvOption { - return ext.Bindings(ext.BindingsVersion(version)) - }, - "encoders": func(version uint32) cel.EnvOption { - return ext.Encoders(ext.EncodersVersion(version)) - }, - "lists": func(version uint32) cel.EnvOption { - return ext.Lists(ext.ListsVersion(version)) - }, - "math": func(version uint32) cel.EnvOption { - return ext.Math(ext.MathVersion(version)) - }, - "protos": func(version uint32) cel.EnvOption { - return ext.Protos(ext.ProtosVersion(version)) - }, - "sets": func(version uint32) cel.EnvOption { - return ext.Sets(ext.SetsVersion(version)) - }, - "strings": func(version uint32) cel.EnvOption { - return ext.Strings(ext.StringsVersion(version)) - }, - "two-var-comprehensions": func(version uint32) cel.EnvOption { - return ext.TwoVarComprehensions(ext.TwoVarComprehensionsVersion(version)) - }, + return cel.FromConfig(config, ext.ExtensionOptionFactory) } From 8890f56dd657d3f4746ad1c53f55b65574457d29 Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 21 Mar 2025 11:18:18 -0700 Subject: [PATCH 18/46] Update strings.format to adhere to the specification (#1133) * Update strings.format to adhere to the specification Signed-off-by: Justin King * Remove usage of ext.Format from repl Signed-off-by: Justin King * Update strings.format changes to be versioned Signed-off-by: Justin King * Update ext/BUILD.bazel Signed-off-by: Justin King * Fix repl/evaluator.go Signed-off-by: Justin King * Implement fmt.Stringer for various ref.Value to return CEL syntax string Signed-off-by: Justin King * Fix tests Signed-off-by: Justin King * Move formatting to standalone function Signed-off-by: Justin King * Deal with error, unknown, custom opaques, and custom objects Signed-off-by: Justin King * Add basic docs to exported methods Signed-off-by: Justin King --------- Signed-off-by: Justin King --- WORKSPACE | 4 +- 
common/types/BUILD.bazel | 1 + common/types/bool.go | 9 + common/types/bytes.go | 15 + common/types/double.go | 22 + common/types/duration.go | 5 + common/types/format.go | 42 + common/types/int.go | 5 + common/types/list.go | 16 + common/types/map.go | 36 + common/types/null.go | 5 + common/types/object.go | 29 + common/types/optional.go | 11 + common/types/string.go | 4 + common/types/timestamp.go | 4 + common/types/types.go | 4 + common/types/uint.go | 6 + conformance/BUILD.bazel | 7 + conformance/conformance_test.go | 1 + conformance/go.mod | 3 +- conformance/go.sum | 12 +- ext/BUILD.bazel | 7 +- ext/formatting.go | 23 + ext/formatting_test.go | 1136 +++++++++++++++++ ext/formatting_v2.go | 788 ++++++++++++ ext/formatting_v2_test.go | 1015 +++++++++++++++ ext/strings.go | 59 +- ext/strings_test.go | 1109 ---------------- go.mod | 4 +- go.sum | 8 +- repl/evaluator.go | 8 +- repl/evaluator_test.go | 16 +- repl/go.mod | 4 +- repl/go.sum | 4 +- vendor/cel.dev/expr/README.md | 2 - vendor/golang.org/x/text/LICENSE | 4 +- .../x/text/internal/catmsg/codec.go | 2 +- vendor/modules.txt | 4 +- 38 files changed, 3258 insertions(+), 1176 deletions(-) create mode 100644 common/types/format.go create mode 100644 ext/formatting_test.go create mode 100644 ext/formatting_v2.go create mode 100644 ext/formatting_v2_test.go diff --git a/WORKSPACE b/WORKSPACE index 133a6b8e3..b52b8319a 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -101,8 +101,8 @@ go_repository( go_repository( name = "dev_cel_expr", importpath = "cel.dev/expr", - sum = "h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4=", - version = "v0.19.1", + sum = "h1:o+Wj235dy4gFYlYin3JsMpp3EEfMrPm/6tdoyjT98S0=", + version = "v0.21.2", ) # local_repository( diff --git a/common/types/BUILD.bazel b/common/types/BUILD.bazel index 8f010fae4..7082bc755 100644 --- a/common/types/BUILD.bazel +++ b/common/types/BUILD.bazel @@ -18,6 +18,7 @@ go_library( "int.go", "iterator.go", "json_value.go", + "format.go", "list.go", "map.go", "null.go", 
diff --git a/common/types/bool.go b/common/types/bool.go index 565734f3f..1f9e10739 100644 --- a/common/types/bool.go +++ b/common/types/bool.go @@ -18,6 +18,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "github.com/google/cel-go/common/types/ref" @@ -128,6 +129,14 @@ func (b Bool) Value() any { return bool(b) } +func (b Bool) format(sb *strings.Builder) { + if b { + sb.WriteString("true") + } else { + sb.WriteString("false") + } +} + // IsBool returns whether the input ref.Val or ref.Type is equal to BoolType. func IsBool(elem ref.Val) bool { switch v := elem.(type) { diff --git a/common/types/bytes.go b/common/types/bytes.go index 7e813e291..b59e1fc20 100644 --- a/common/types/bytes.go +++ b/common/types/bytes.go @@ -19,6 +19,7 @@ import ( "encoding/base64" "fmt" "reflect" + "strings" "unicode/utf8" "github.com/google/cel-go/common/types/ref" @@ -138,3 +139,17 @@ func (b Bytes) Type() ref.Type { func (b Bytes) Value() any { return []byte(b) } + +func (b Bytes) format(sb *strings.Builder) { + fmt.Fprintf(sb, "b\"%s\"", bytesToOctets([]byte(b))) +} + +// bytesToOctets converts byte sequences to a string using a three digit octal encoded value +// per byte. 
+func bytesToOctets(byteVal []byte) string { + var b strings.Builder + for _, c := range byteVal { + fmt.Fprintf(&b, "\\%03o", c) + } + return b.String() +} diff --git a/common/types/double.go b/common/types/double.go index 027e78978..1e7de9d6e 100644 --- a/common/types/double.go +++ b/common/types/double.go @@ -18,6 +18,8 @@ import ( "fmt" "math" "reflect" + "strconv" + "strings" "github.com/google/cel-go/common/types/ref" @@ -209,3 +211,23 @@ func (d Double) Type() ref.Type { func (d Double) Value() any { return float64(d) } + +func (d Double) format(sb *strings.Builder) { + if math.IsNaN(float64(d)) { + sb.WriteString(`double("NaN")`) + return + } + if math.IsInf(float64(d), -1) { + sb.WriteString(`double("-Infinity")`) + return + } + if math.IsInf(float64(d), 1) { + sb.WriteString(`double("Infinity")`) + return + } + s := strconv.FormatFloat(float64(d), 'f', -1, 64) + sb.WriteString(s) + if !strings.ContainsRune(s, '.') { + sb.WriteString(".0") + } +} diff --git a/common/types/duration.go b/common/types/duration.go index 596e56d6b..be58d567e 100644 --- a/common/types/duration.go +++ b/common/types/duration.go @@ -18,6 +18,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "time" "github.com/google/cel-go/common/overloads" @@ -185,6 +186,10 @@ func (d Duration) Value() any { return d.Duration } +func (d Duration) format(sb *strings.Builder) { + fmt.Fprintf(sb, `duration("%ss")`, strconv.FormatFloat(d.Seconds(), 'f', -1, 64)) +} + // DurationGetHours returns the duration in hours. 
func DurationGetHours(val ref.Val) ref.Val { dur, ok := val.(Duration) diff --git a/common/types/format.go b/common/types/format.go new file mode 100644 index 000000000..174a2bd04 --- /dev/null +++ b/common/types/format.go @@ -0,0 +1,42 @@ +package types + +import ( + "fmt" + "strings" + + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +type formattable interface { + format(*strings.Builder) +} + +// Format formats the value as a string. The result is only intended for human consumption and ignores errors. +// Do not depend on the output being stable. It may change at any time. +func Format(val ref.Val) string { + var sb strings.Builder + formatTo(&sb, val) + return sb.String() +} + +func formatTo(sb *strings.Builder, val ref.Val) { + if fmtable, ok := val.(formattable); ok { + fmtable.format(sb) + return + } + // All of the builtins implement formattable. Try to deal with traits. + if l, ok := val.(traits.Lister); ok { + formatList(l, sb) + return + } + if m, ok := val.(traits.Mapper); ok { + formatMap(m, sb) + return + } + // This could be an error, unknown, opaque or object. + // Unfortunately we have no consistent way of inspecting + // opaque and object. So we just fallback to fmt.Stringer + // and hope it is relavent. + fmt.Fprintf(sb, "%s", val) +} diff --git a/common/types/int.go b/common/types/int.go index 0ae9507c3..0ac1997b7 100644 --- a/common/types/int.go +++ b/common/types/int.go @@ -19,6 +19,7 @@ import ( "math" "reflect" "strconv" + "strings" "time" "github.com/google/cel-go/common/types/ref" @@ -290,6 +291,10 @@ func (i Int) Value() any { return int64(i) } +func (i Int) format(sb *strings.Builder) { + sb.WriteString(strconv.FormatInt(int64(i), 10)) +} + // isJSONSafe indicates whether the int is safely representable as a floating point value in JSON. 
func (i Int) isJSONSafe() bool { return i >= minIntJSON && i <= maxIntJSON diff --git a/common/types/list.go b/common/types/list.go index 7e68a5daf..8c023f891 100644 --- a/common/types/list.go +++ b/common/types/list.go @@ -299,6 +299,22 @@ func (l *baseList) String() string { return sb.String() } +func formatList(l traits.Lister, sb *strings.Builder) { + sb.WriteString("[") + n, _ := l.Size().(Int) + for i := 0; i < int(n); i++ { + formatTo(sb, l.Get(Int(i))) + if i != int(n)-1 { + sb.WriteString(", ") + } + } + sb.WriteString("]") +} + +func (l *baseList) format(sb *strings.Builder) { + formatList(l, sb) +} + // mutableList aggregates values into its internal storage. For use with internal CEL variables only. type mutableList struct { *baseList diff --git a/common/types/map.go b/common/types/map.go index cb6cce78b..b33096197 100644 --- a/common/types/map.go +++ b/common/types/map.go @@ -17,6 +17,7 @@ package types import ( "fmt" "reflect" + "sort" "strings" "github.com/stoewer/go-strcase" @@ -318,6 +319,41 @@ func (m *baseMap) String() string { return sb.String() } +type baseMapEntry struct { + key string + val string +} + +func formatMap(m traits.Mapper, sb *strings.Builder) { + it := m.Iterator() + var ents []baseMapEntry + if s, ok := m.Size().(Int); ok { + ents = make([]baseMapEntry, 0, int(s)) + } + for it.HasNext() == True { + k := it.Next() + v, _ := m.Find(k) + ents = append(ents, baseMapEntry{Format(k), Format(v)}) + } + sort.SliceStable(ents, func(i, j int) bool { + return ents[i].key < ents[j].key + }) + sb.WriteString("{") + for i, ent := range ents { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(ent.key) + sb.WriteString(": ") + sb.WriteString(ent.val) + } + sb.WriteString("}") +} + +func (m *baseMap) format(sb *strings.Builder) { + formatMap(m, sb) +} + // Type implements the ref.Val interface method. 
func (m *baseMap) Type() ref.Type { return MapType diff --git a/common/types/null.go b/common/types/null.go index 36514ff20..2c0297fe6 100644 --- a/common/types/null.go +++ b/common/types/null.go @@ -17,6 +17,7 @@ package types import ( "fmt" "reflect" + "strings" "google.golang.org/protobuf/proto" @@ -117,3 +118,7 @@ func (n Null) Type() ref.Type { func (n Null) Value() any { return structpb.NullValue_NULL_VALUE } + +func (n Null) format(sb *strings.Builder) { + sb.WriteString("null") +} diff --git a/common/types/object.go b/common/types/object.go index 5377bff8d..776f6954a 100644 --- a/common/types/object.go +++ b/common/types/object.go @@ -17,9 +17,12 @@ package types import ( "fmt" "reflect" + "sort" + "strings" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" "github.com/google/cel-go/common/types/pb" "github.com/google/cel-go/common/types/ref" @@ -163,3 +166,29 @@ func (o *protoObj) Type() ref.Type { func (o *protoObj) Value() any { return o.value } + +type protoObjField struct { + fd protoreflect.FieldDescriptor + v protoreflect.Value +} + +func (o *protoObj) format(sb *strings.Builder) { + var fields []protoreflect.FieldDescriptor + o.value.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + fields = append(fields, fd) + return true + }) + sort.SliceStable(fields, func(i, j int) bool { + return fields[i].Number() < fields[j].Number() + }) + sb.WriteString(o.Type().TypeName()) + sb.WriteString("{") + for i, field := range fields { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(fmt.Sprintf("%s: ", field.Name())) + formatTo(sb, o.Get(String(field.Name()))) + } + sb.WriteString("}") +} diff --git a/common/types/optional.go b/common/types/optional.go index 97845a740..b8685ebf5 100644 --- a/common/types/optional.go +++ b/common/types/optional.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "reflect" + "strings" 
"github.com/google/cel-go/common/types/ref" ) @@ -94,6 +95,16 @@ func (o *Optional) String() string { return "optional.none()" } +func (o *Optional) format(sb *strings.Builder) { + if o.HasValue() { + sb.WriteString(`optional.of(`) + formatTo(sb, o.GetValue()) + sb.WriteString(`)`) + } else { + sb.WriteString("optional.none()") + } +} + // Type implements the ref.Val interface method. func (o *Optional) Type() ref.Type { return OptionalType diff --git a/common/types/string.go b/common/types/string.go index 3a93743f2..8aad4701c 100644 --- a/common/types/string.go +++ b/common/types/string.go @@ -186,6 +186,10 @@ func (s String) Value() any { return string(s) } +func (s String) format(sb *strings.Builder) { + sb.WriteString(strconv.Quote(string(s))) +} + // StringContains returns whether the string contains a substring. func StringContains(s, sub ref.Val) ref.Val { str, ok := s.(String) diff --git a/common/types/timestamp.go b/common/types/timestamp.go index 33acdea8e..f7be58591 100644 --- a/common/types/timestamp.go +++ b/common/types/timestamp.go @@ -179,6 +179,10 @@ func (t Timestamp) Value() any { return t.Time } +func (t Timestamp) format(sb *strings.Builder) { + fmt.Fprintf(sb, `timestamp("%s")`, t.Time.UTC().Format(time.RFC3339Nano)) +} + var ( timestampValueType = reflect.TypeOf(&tpb.Timestamp{}) diff --git a/common/types/types.go b/common/types/types.go index d5ce60f16..3ed514093 100644 --- a/common/types/types.go +++ b/common/types/types.go @@ -376,6 +376,10 @@ func (t *Type) TypeName() string { return t.runtimeTypeName } +func (t *Type) format(sb *strings.Builder) { + sb.WriteString(t.TypeName()) +} + // WithTraits creates a copy of the current Type and sets the trait mask to the traits parameter. // // This method should be used with Opaque types where the type acts like a container, e.g. vector. 
diff --git a/common/types/uint.go b/common/types/uint.go index 6d74f30d8..a93405a13 100644 --- a/common/types/uint.go +++ b/common/types/uint.go @@ -19,6 +19,7 @@ import ( "math" "reflect" "strconv" + "strings" "github.com/google/cel-go/common/types/ref" @@ -250,6 +251,11 @@ func (i Uint) Value() any { return uint64(i) } +func (i Uint) format(sb *strings.Builder) { + sb.WriteString(strconv.FormatUint(uint64(i), 10)) + sb.WriteString("u") +} + // isJSONSafe indicates whether the uint is safely representable as a floating point value in JSON. func (i Uint) isJSONSafe() bool { return i <= maxIntJSON diff --git a/conformance/BUILD.bazel b/conformance/BUILD.bazel index 4bca746d2..6b2e50bcf 100644 --- a/conformance/BUILD.bazel +++ b/conformance/BUILD.bazel @@ -55,6 +55,13 @@ _TESTS_TO_SKIP = [ # Future enhancments. "enums/strong_proto2", "enums/strong_proto3", + + # Type deductions + "type_deductions/wrappers/wrapper_promotion_2", + "type_deductions/legacy_nullable_types/null_assignable_to_message_parameter_candidate", + "type_deductions/legacy_nullable_types/null_assignable_to_duration_parameter_candidate", + "type_deductions/legacy_nullable_types/null_assignable_to_timestamp_parameter_candidate", + "type_deductions/legacy_nullable_types/null_assignable_to_abstract_parameter_candidate", ] go_test( diff --git a/conformance/conformance_test.go b/conformance/conformance_test.go index 4eead5c4d..4b37ea2e8 100644 --- a/conformance/conformance_test.go +++ b/conformance/conformance_test.go @@ -92,6 +92,7 @@ func init() { ext.Protos(), ext.Strings(), cel.Lib(celBlockLib{}), + cel.EnableIdentifierEscapeSyntax(), } var err error diff --git a/conformance/go.mod b/conformance/go.mod index 65aa6edd6..b666bfe67 100644 --- a/conformance/go.mod +++ b/conformance/go.mod @@ -3,7 +3,7 @@ module github.com/google/cel-go/conformance go 1.21.1 require ( - cel.dev/expr v0.19.1 + cel.dev/expr v0.21.2 github.com/bazelbuild/rules_go v0.49.0 github.com/google/cel-go v0.21.0 
github.com/google/go-cmp v0.6.0 @@ -14,7 +14,6 @@ require ( github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect - golang.org/x/text v0.16.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 // indirect ) diff --git a/conformance/go.sum b/conformance/go.sum index 8e4f44922..cadf61fd9 100644 --- a/conformance/go.sum +++ b/conformance/go.sum @@ -1,9 +1,5 @@ -cel.dev/expr v0.18.0 h1:CJ6drgk+Hf96lkLikr4rFf19WrU0BOWEihyZnI2TAzo= -cel.dev/expr v0.18.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= -cel.dev/expr v0.19.0 h1:lXuo+nDhpyJSpWxpPVi5cPUwzKb+dsdOiw6IreM5yt0= -cel.dev/expr v0.19.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= -cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= -cel.dev/expr v0.19.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.21.2 h1:o+Wj235dy4gFYlYin3JsMpp3EEfMrPm/6tdoyjT98S0= +cel.dev/expr v0.21.2/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/bazelbuild/rules_go v0.49.0 h1:5vCbuvy8Q11g41lseGJDc5vxhDjJtfxr6nM/IC4VmqM= @@ -21,8 +17,6 @@ github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod 
h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 h1:YcyjlL1PRr2Q17/I0dPk2JmYS5CDXfcdb2Z3YRioEbw= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 h1:2035KHhUv+EpyB+hWgJnaWKJOdX1E95w2S8Rr4uWKTs= @@ -32,3 +26,5 @@ google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWn gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index 62863c17a..24066b864 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -12,6 +12,7 @@ go_library( "encoders.go", "extension_option_factory.go", "formatting.go", + "formatting_v2.go", "guards.go", "lists.go", "math.go", @@ -28,8 +29,8 @@ go_library( "//common/ast:go_default_library", "//common/decls:go_default_library", "//common/env:go_default_library", - "//common/overloads:go_default_library", "//common/operators:go_default_library", + "//common/overloads:go_default_library", "//common/types:go_default_library", "//common/types/pb:go_default_library", "//common/types/ref:go_default_library", @@ -52,6 +53,8 @@ go_test( "comprehensions_test.go", "encoders_test.go", "extension_option_factory_test.go", + "formatting_test.go", + "formatting_v2_test.go", "lists_test.go", "math_test.go", "native_test.go", @@ -72,8 +75,8 @@ go_test( "//test:go_default_library", "//test/proto2pb:go_default_library", "//test/proto3pb:go_default_library", + 
"@org_golang_google_protobuf//encoding/protojson:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//types/known/wrapperspb:go_default_library", - "@org_golang_google_protobuf//encoding/protojson:go_default_library", ], ) diff --git a/ext/formatting.go b/ext/formatting.go index aa334ccd9..111184b73 100644 --- a/ext/formatting.go +++ b/ext/formatting.go @@ -268,14 +268,17 @@ func makeMatcher(locale string) (language.Matcher, error) { type stringFormatter struct{} +// String implements formatStringInterpolator.String. func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) { return FormatString(arg, locale) } +// Decimal implements formatStringInterpolator.Decimal. func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) { return formatDecimal(arg, locale) } +// Fixed implements formatStringInterpolator.Fixed. func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) { if precision == nil { precision = new(int) @@ -307,6 +310,7 @@ func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, e } } +// Scientific implements formatStringInterpolator.Scientific. func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) { if precision == nil { precision = new(int) @@ -337,6 +341,7 @@ func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (stri } } +// Binary implements formatStringInterpolator.Binary. func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) { switch arg.Type() { case types.IntType: @@ -358,6 +363,7 @@ func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) { } } +// Hex implements formatStringInterpolator.Hex. 
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) { return func(arg ref.Val, locale string) (string, error) { fmtStr := "%x" @@ -388,6 +394,7 @@ func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, erro } } +// Octal implements formatStringInterpolator.Octal. func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) { switch arg.Type() { case types.IntType: @@ -504,6 +511,7 @@ type stringFormatChecker struct { ast *ast.AST } +// String implements formatStringInterpolator.String. func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) { formatArg := c.args[c.currArgIndex] valid, badID := c.verifyString(formatArg) @@ -513,6 +521,7 @@ func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) return "", nil } +// Decimal implements formatStringInterpolator.Decimal. func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) @@ -522,6 +531,7 @@ func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error return "", nil } +// Fixed implements formatStringInterpolator.Fixed. func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (string, error) { return func(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() @@ -534,6 +544,7 @@ func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (strin } } +// Scientific implements formatStringInterpolator.Scientific. func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) (string, error) { return func(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() @@ -545,6 +556,7 @@ func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) ( } } +// Binary implements formatStringInterpolator.Binary. 
func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.BoolType) @@ -554,6 +566,7 @@ func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) return "", nil } +// Hex implements formatStringInterpolator.Hex. func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, error) { return func(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() @@ -565,6 +578,7 @@ func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, } } +// Octal implements formatStringInterpolator.Octal. func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) @@ -574,6 +588,7 @@ func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) return "", nil } +// Arg implements formatListArgs.Arg. func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) { c.argsRequested++ c.currArgIndex = index @@ -582,6 +597,7 @@ func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) { return types.Int(0), nil } +// Size implements formatListArgs.Size. func (c *stringFormatChecker) Size() int64 { return int64(len(c.args)) } @@ -686,10 +702,12 @@ func newFormatError(id int64, msg string, args ...any) error { } } +// Error implements error. func (e formatError) Error() string { return e.msg } +// Is implements errors.Is. func (e formatError) Is(target error) bool { return e.msg == target.Error() } @@ -699,6 +717,7 @@ type stringArgList struct { args traits.Lister } +// Arg implements formatListArgs.Arg. 
func (c *stringArgList) Arg(index int64) (ref.Val, error) { if index >= c.args.Size().Value().(int64) { return nil, fmt.Errorf("index %d out of range", index) @@ -706,6 +725,7 @@ func (c *stringArgList) Arg(index int64) (ref.Val, error) { return c.args.Get(types.Int(index)), nil } +// Size implements formatListArgs.Size. func (c *stringArgList) Size() int64 { return c.args.Size().Value().(int64) } @@ -887,14 +907,17 @@ func newParseFormatError(msg string, wrapped error) error { return parseFormatError{msg: msg, wrapped: wrapped} } +// Error implements error. func (e parseFormatError) Error() string { return fmt.Sprintf("%s: %s", e.msg, e.wrapped.Error()) } +// Is implements errors.Is. func (e parseFormatError) Is(target error) bool { return e.Error() == target.Error() } +// Unwrap implements errors.Unwrap. func (e parseFormatError) Unwrap() error { return e.wrapped } diff --git a/ext/formatting_test.go b/ext/formatting_test.go new file mode 100644 index 000000000..6b1f25066 --- /dev/null +++ b/ext/formatting_test.go @@ -0,0 +1,1136 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ext + +import ( + "fmt" + "math" + "reflect" + "strings" + "testing" + "time" + + "google.golang.org/protobuf/proto" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + + proto3pb "github.com/google/cel-go/test/proto3pb" +) + +func TestStringFormat(t *testing.T) { + tests := []struct { + name string + format string + dynArgs map[string]any + formatArgs string + locale string + err string + expectedOutput string + expectedRuntimeCost uint64 + expectedEstimatedCost checker.CostEstimate + skipCompileCheck bool + }{ + { + name: "no-op", + format: "no substitution", + expectedOutput: "no substitution", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + + { + name: "mid-string substitution", + format: "str is %s and some more", + formatArgs: `"filler"`, + expectedOutput: "str is filler and some more", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "percent escaping", + format: "%% and also %%", + expectedOutput: "% and also %", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "substitution inside escaped percent signs", + format: "%%%s%%", + formatArgs: `"text"`, + expectedOutput: "%text%", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "substitution with one escaped percent sign on the right", + format: "%s%%", + formatArgs: `"percent on the right"`, + expectedOutput: "percent on the right%", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "substitution with one escaped percent sign on the left", + format: "%%%s", + formatArgs: `"percent on the left"`, + expectedOutput: "%percent on the left", + expectedRuntimeCost: 11, + expectedEstimatedCost: 
checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "multiple substitutions", + format: "%d %d %d, %s %s %s, %d %d %d, %s %s %s", + formatArgs: `1, 2, 3, "A", "B", "C", 4, 5, 6, "D", "E", "F"`, + expectedOutput: "1 2 3, A B C, 4 5 6, D E F", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "percent sign escape sequence support", + format: "\u0025\u0025escaped \u0025s\u0025\u0025", + formatArgs: `"percent"`, + expectedOutput: "%escaped percent%", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "fixed point formatting clause", + format: "%.3f", + formatArgs: "1.2345", + expectedOutput: "1.234", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "binary formatting clause", + format: "this is 5 in binary: %b", + formatArgs: "5", + expectedOutput: "this is 5 in binary: 101", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "negative binary formatting clause", + format: "this is -5 in binary: %b", + formatArgs: "-5", + expectedOutput: "this is -5 in binary: -101", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "uint support for binary formatting", + format: "unsigned 64 in binary: %b", + formatArgs: "uint(64)", + expectedOutput: "unsigned 64 in binary: 1000000", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "bool support for binary formatting", + format: "bit set from bool: %b", + formatArgs: "true", + expectedOutput: "bit set from bool: 1", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "octal formatting clause", + format: "%o", + formatArgs: "11", + expectedOutput: "13", + expectedRuntimeCost: 11, + expectedEstimatedCost: 
checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "negative octal formatting clause", + format: "%o", + formatArgs: "-11", + expectedOutput: "-13", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "uint support for octal formatting clause", + format: "this is an unsigned octal: %o", + formatArgs: "uint(65535)", + expectedOutput: "this is an unsigned octal: 177777", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "lowercase hexadecimal formatting clause", + format: "%x is 30 in hexadecimal", + formatArgs: "30", + expectedOutput: "1e is 30 in hexadecimal", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "uppercase hexadecimal formatting clause", + format: "%X is 20 in hexadecimal", + formatArgs: "30", + expectedOutput: "1E is 20 in hexadecimal", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "negative hexadecimal formatting clause", + format: "%x is -30 in hexadecimal", + formatArgs: "-30", + expectedOutput: "-1e is -30 in hexadecimal", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "unsigned support for hexadecimal formatting clause", + format: "%X is 6000 in hexadecimal", + formatArgs: "uint(6000)", + expectedOutput: "1770 is 6000 in hexadecimal", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "string support with hexadecimal formatting clause", + format: "%x", + formatArgs: `"Hello world!"`, + expectedOutput: "48656c6c6f20776f726c6421", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "string support with uppercase hexadecimal formatting clause", + format: "%X", + formatArgs: `"Hello world!"`, + expectedOutput: 
"48656C6C6F20776F726C6421", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "byte support with hexadecimal formatting clause", + format: "%x", + formatArgs: `b"byte string"`, + expectedOutput: "6279746520737472696e67", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "byte support with hexadecimal formatting clause leading zero", + format: "%x", + formatArgs: `b"\x00\x00byte string\x00"`, + expectedOutput: "00006279746520737472696e6700", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "byte support with uppercase hexadecimal formatting clause", + format: "%X", + formatArgs: `b"byte string"`, + expectedOutput: "6279746520737472696E67", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "scientific notation formatting clause", + format: "%.6e", + formatArgs: "1052.032911275", + expectedOutput: "1.052033\u202f\u00d7\u202f10\u2070\u00b3", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "locale support", + format: "%.3f", + formatArgs: "3.14", + locale: "fr_FR", + expectedOutput: "3,140", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "default precision for fixed-point clause", + format: "%f", + formatArgs: "2.71828", + expectedOutput: "2.718280", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "default precision for scientific notation", + format: "%e", + formatArgs: "2.71828", + expectedOutput: "2.718280\u202f\u00d7\u202f10\u2070\u2070", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "default precision for string", + format: "%s", + 
formatArgs: "2.71", + expectedOutput: "2.71", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "default list precision for string", + format: "%s", + formatArgs: "[2.71]", + expectedOutput: "[2.710000]", + expectedRuntimeCost: 21, + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + locale: "en_US", + }, + { + name: "default scientific notation for string", + format: "%s", + formatArgs: "0.000000002", + expectedOutput: "2e-09", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "default list scientific notation for string", + format: "%s", + formatArgs: "[0.000000002]", + expectedOutput: "[0.000000]", + expectedRuntimeCost: 21, + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + locale: "en_US", + }, + { + name: "unicode output for scientific notation", + format: "unescaped unicode: %e, escaped unicode: %e", + formatArgs: "2.71828, 2.71828", + expectedOutput: "unescaped unicode: 2.718280 × 10⁰⁰, escaped unicode: 2.718280\u202f\u00d7\u202f10\u2070\u2070", + expectedRuntimeCost: 15, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, + locale: "en_US", + }, + { + name: "NaN support for fixed-point", + format: "%f", + formatArgs: `"NaN"`, + expectedOutput: "NaN", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "positive infinity support for fixed-point", + format: "%f", + formatArgs: `"Infinity"`, + expectedOutput: "∞", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "negative infinity support for fixed-point", + format: "%f", + formatArgs: `"-Infinity"`, + expectedOutput: "-∞", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + locale: "en_US", + }, + { + name: "NaN support 
for string", + format: "%s", + formatArgs: `double("NaN")`, + expectedOutput: "NaN", + }, + { + name: "positive infinity support for string", + format: "%s", + formatArgs: `double("Inf")`, + expectedOutput: "+Inf", + }, + { + name: "negative infinity support for string", + format: "%s", + formatArgs: `double("-Inf")`, + expectedOutput: "-Inf", + }, + { + name: "infinity list support for string", + format: "%s", + formatArgs: `[double("NaN"),double("+Inf"), double("-Inf")]`, + expectedOutput: `["NaN", "+Inf", "-Inf"]`, + }, + { + name: "uint support for decimal clause", + format: "%d", + formatArgs: "uint(64)", + expectedOutput: "64", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "null support for string", + format: "null: %s", + formatArgs: "null", + expectedOutput: "null: null", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "int support for string", + format: "%s", + formatArgs: `999999999999`, + expectedOutput: "999999999999", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "bytes support for string", + format: "some bytes: %s", + formatArgs: `b"xyz"`, + expectedOutput: "some bytes: xyz", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "type() support for string", + format: "type is %s", + formatArgs: `type("test string")`, + expectedOutput: "type is string", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "timestamp support for string", + format: "%s", + formatArgs: `timestamp("2023-02-03T23:31:20+00:00")`, + expectedOutput: "2023-02-03T23:31:20Z", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "duration support for string", + format: "%s", + formatArgs: `duration("1h45m47s")`, + expectedOutput: 
"6347s", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "small duration support for string", + format: "%s", + formatArgs: `duration("2ns")`, + expectedOutput: "0.000000002s", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "list support for string", + format: "%s", + formatArgs: `["abc", 3.14, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")]`, + expectedOutput: `["abc", 3.140000, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")]`, + expectedRuntimeCost: 32, + expectedEstimatedCost: checker.CostEstimate{Min: 32, Max: 32}, + }, + { + name: "map support for string", + format: "%s", + formatArgs: `{"key1": b"xyz", "key5": null, "key2": duration("2h"), "key4": true, "key3": 2.71828}`, + locale: "nl_NL", + expectedOutput: `{"key1":b"xyz", "key2":duration("7200s"), "key3":2.718280, "key4":true, "key5":null}`, + expectedRuntimeCost: 42, + expectedEstimatedCost: checker.CostEstimate{Min: 42, Max: 42}, + }, + { + name: "map support (all key types)", + format: "map with multiple key types: %s", + formatArgs: `{1: "value1", uint(2): "value2", true: double("NaN")}`, + expectedOutput: `map with multiple key types: {1:"value1", 2:"value2", true:"NaN"}`, + expectedRuntimeCost: 46, + expectedEstimatedCost: checker.CostEstimate{Min: 46, Max: 46}, + }, + { + name: "boolean support for %s", + format: "true bool: %s, false bool: %s", + formatArgs: `true, false`, + expectedOutput: "true bool: true, false bool: false", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for string formatting clause", + format: "dynamic string: %s", + formatArgs: `dynStr`, + dynArgs: map[string]any{ + "dynStr": "a string", + }, + expectedOutput: "dynamic string: a string", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for 
numbers with string formatting clause", + format: "dynIntStr: %s dynDoubleStr: %s", + formatArgs: `dynIntStr, dynDoubleStr`, + dynArgs: map[string]any{ + "dynIntStr": 32, + "dynDoubleStr": 56.8, + }, + expectedOutput: "dynIntStr: 32 dynDoubleStr: 56.8", + expectedRuntimeCost: 15, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, + locale: "en_US", + }, + { + name: "dyntype support for integer formatting clause", + format: "dynamic int: %d", + formatArgs: `dynInt`, + dynArgs: map[string]any{ + "dynInt": 128, + }, + expectedOutput: "dynamic int: 128", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for integer formatting clause (unsigned)", + format: "dynamic unsigned int: %d", + formatArgs: `dynUnsignedInt`, + dynArgs: map[string]any{ + "dynUnsignedInt": uint64(256), + }, + expectedOutput: "dynamic unsigned int: 256", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "dyntype support for hex formatting clause", + format: "dynamic hex int: %x", + formatArgs: `dynHexInt`, + dynArgs: map[string]any{ + "dynHexInt": 22, + }, + expectedOutput: "dynamic hex int: 16", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for hex formatting clause (uppercase)", + format: "dynamic hex int: %X (uppercase)", + formatArgs: `dynHexInt`, + dynArgs: map[string]any{ + "dynHexInt": 26, + }, + expectedOutput: "dynamic hex int: 1A (uppercase)", + expectedRuntimeCost: 15, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, + }, + { + name: "dyntype support for unsigned hex formatting clause", + format: "dynamic hex int: %x (unsigned)", + formatArgs: `dynUnsignedHexInt`, + dynArgs: map[string]any{ + "dynUnsignedHexInt": uint(500), + }, + expectedOutput: "dynamic hex int: 1f4 (unsigned)", + expectedRuntimeCost: 14, + expectedEstimatedCost: 
checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "dyntype support for fixed-point formatting clause", + format: "dynamic double: %.3f", + formatArgs: `dynDouble`, + dynArgs: map[string]any{ + "dynDouble": 4.5, + }, + expectedOutput: "dynamic double: 4.500", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + locale: "en_US", + }, + { + name: "dyntype support for fixed-point formatting clause (comma separator locale)", + format: "dynamic double: %f", + formatArgs: `dynDouble`, + dynArgs: map[string]any{ + "dynDouble": 4.5, + }, + expectedOutput: "dynamic double: 4,500000", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + locale: "fr_FR", + }, + { + name: "dyntype support for scientific notation", + format: "(dyntype) e: %e", + formatArgs: "dynE", + dynArgs: map[string]any{ + "dynE": 2.71828, + }, + expectedOutput: "(dyntype) e: 2.718280\u202f\u00d7\u202f10\u2070\u2070", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + locale: "en_US", + }, + { + name: "dyntype NaN/infinity support for fixed-point", + format: "NaN: %f, infinity: %f", + formatArgs: `dynNaN, dynInf`, + dynArgs: map[string]any{ + "dynNaN": math.NaN(), + "dynInf": math.Inf(1), + }, + expectedOutput: "NaN: NaN, infinity: ∞", + expectedRuntimeCost: 15, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, + }, + { + name: "dyntype support for timestamp", + format: "dyntype timestamp: %s", + formatArgs: `dynTime`, + dynArgs: map[string]any{ + "dynTime": time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), + }, + expectedOutput: "dyntype timestamp: 2009-11-10T23:00:00Z", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "dyntype support for duration", + format: "dyntype duration: %s", + formatArgs: `dynDuration`, + dynArgs: map[string]any{ + "dynDuration": mustParseDuration("2h25m47s"), + }, + 
expectedOutput: "dyntype duration: 8747s", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for lists", + format: "dyntype list: %s", + formatArgs: `dynList`, + dynArgs: map[string]any{ + "dynList": []any{6, 4.2, "a string"}, + }, + expectedOutput: `dyntype list: [6, 4.200000, "a string"]`, + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for maps", + format: "dyntype map: %s", + formatArgs: `dynMap`, + dynArgs: map[string]any{ + "dynMap": map[any]any{ + "strKey": "x", + true: 42, + int64(6): mustParseDuration("7m2s"), + }, + }, + expectedOutput: `dyntype map: {"strKey":"x", 6:duration("422s"), true:42}`, + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "message field support", + format: "message field msg.single_int32: %d, msg.single_double: %.1f", + formatArgs: `msg.single_int32, msg.single_double`, + dynArgs: map[string]any{ + "msg": &proto3pb.TestAllTypes{ + SingleInt32: 2, + SingleDouble: 1.0, + }, + }, + locale: "en_US", + expectedOutput: `message field msg.single_int32: 2, msg.single_double: 1.0`, + }, + { + name: "unrecognized formatting clause", + format: "%a", + formatArgs: "1", + skipCompileCheck: true, + err: "could not parse formatting clause: unrecognized formatting clause \"a\"", + }, + { + name: "out of bounds arg index", + format: "%d %d %d", + formatArgs: "0, 1", + skipCompileCheck: true, + err: "index 2 out of range", + }, + { + name: "string substitution is not allowed with binary clause", + format: "string is %b", + formatArgs: `"abc"`, + skipCompileCheck: true, + err: "error during formatting: only integers and bools can be formatted as binary, was given string", + }, + { + name: "duration substitution not allowed with decimal clause", + format: "%d", + formatArgs: `duration("30m2s")`, + skipCompileCheck: true, + err: "error during 
formatting: decimal clause can only be used on integers, was given google.protobuf.Duration", + }, + { + name: "string substitution not allowed with octal clause", + format: "octal: %o", + formatArgs: `"a string"`, + skipCompileCheck: true, + err: "error during formatting: octal clause can only be used on integers, was given string", + }, + { + name: "double substitution not allowed with hex clause", + format: "double is %x", + formatArgs: "0.5", + skipCompileCheck: true, + err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex, was given double", + }, + { + name: "uppercase not allowed for scientific clause", + format: "double is %E", + formatArgs: "0.5", + skipCompileCheck: true, + err: `could not parse formatting clause: unrecognized formatting clause "E"`, + }, + { + name: "object not allowed", + format: "object is %s", + formatArgs: `ext.TestAllTypes{PbVal: test.TestAllTypes{}}`, + skipCompileCheck: true, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", + }, + { + name: "object inside list", + format: "%s", + formatArgs: "[1, 2, ext.TestAllTypes{PbVal: test.TestAllTypes{}}]", + skipCompileCheck: true, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", + }, + { + name: "object inside map", + format: "%s", + formatArgs: `{1: "a", 2: ext.TestAllTypes{}}`, + skipCompileCheck: true, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", + }, + { + name: "null not allowed for %d", + format: "null: %d", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: decimal clause can only be used on integers, was given 
null_type", + }, + { + name: "null not allowed for %e", + format: "null: %e", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: scientific clause can only be used on doubles, was given null_type", + }, + { + name: "null not allowed for %f", + format: "null: %f", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: fixed-point clause can only be used on doubles, was given null_type", + }, + { + name: "null not allowed for %x", + format: "null: %x", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex, was given null_type", + }, + { + name: "null not allowed for %X", + format: "null: %X", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex, was given null_type", + }, + { + name: "null not allowed for %b", + format: "null: %b", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: only integers and bools can be formatted as binary, was given null_type", + }, + { + name: "null not allowed for %o", + format: "null: %o", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: octal clause can only be used on integers, was given null_type", + }, + { + name: "compile-time cardinality check (too few for string)", + format: "%s %s", + formatArgs: `"abc"`, + err: "index 1 out of range", + }, + { + name: "compile-time cardinality check (too many for string)", + format: "%s %s", + formatArgs: `"abc", "def", "ghi"`, + err: "too many arguments supplied to string.format (expected 2, got 3)", + }, + { + name: "compile-time syntax check (unexpected end of string)", + format: "filler %", + formatArgs: "", + err: "unexpected end of string", + }, + { + name: "compile-time syntax check (unrecognized formatting clause)", + format: "%j", + // pass args here, otherwise the cardinality check will 
fail first + formatArgs: "123", + err: `could not parse formatting clause: unrecognized formatting clause "j"`, + }, + { + name: "compile-time %s check", + format: "object is %s", + formatArgs: `ext.TestAllTypes{PbVal: test.TestAllTypes{}}`, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps", + }, + { + name: "compile-time check for objects inside list literal", + format: "list is %s", + formatArgs: `[1, 2, ext.TestAllTypes{PbVal: test.TestAllTypes{}}]`, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps", + }, + { + name: "compile-time %d check", + format: "int is %d", + formatArgs: "5.2", + err: "error during formatting: decimal clause can only be used on integers", + }, + { + name: "compile-time %f check", + format: "double is %f", + formatArgs: "true", + err: "error during formatting: fixed-point clause can only be used on doubles", + }, + { + name: "compile-time precision syntax check", + format: "double is %.34", + formatArgs: "5.0", + err: "could not parse formatting clause: error while parsing precision: could not find end of precision specifier", + }, + { + name: "compile-time %e check", + format: "double is %e", + formatArgs: "true", + err: "error during formatting: scientific clause can only be used on doubles", + }, + { + name: "compile-time %b check", + format: "string is %b", + formatArgs: `"a string"`, + err: "error during formatting: only integers and bools can be formatted as binary", + }, + { + name: "compile-time %x check", + format: "%x is a double", + formatArgs: "2.5", + err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex", + }, + { + name: "compile-time %X check", + format: "%X is a double", + formatArgs: "2.5", + err: "error during formatting: only integers, byte buffers, and strings can be formatted 
as hex", + }, + { + name: "compile-time %o check", + format: "an octal: %o", + formatArgs: "3.14", + err: "error during formatting: octal clause can only be used on integers", + }, + } + evalExpr := func(env *cel.Env, expr string, evalArgs any, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { + t.Logf("evaluating expr: %s", expr) + parsedAst, issues := env.Parse(expr) + if issues.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", expr, issues.Err()) + } + checkedAst, issues := env.Check(parsedAst) + if issues.Err() != nil { + return nil, issues.Err() + } + evalOpts := make([]cel.ProgramOption, 0) + costTracker := &noopCostEstimator{} + if expectedRuntimeCost != 0 { + evalOpts = append(evalOpts, cel.CostTracking(costTracker)) + } + program, err := env.Program(checkedAst, evalOpts...) + if err != nil { + return nil, err + } + + actualEstimatedCost, err := env.EstimateCost(checkedAst, costTracker) + if err != nil { + t.Fatal(err) + } + if expectedEstimatedCost.Min != 0 && expectedEstimatedCost.Max != 0 { + if actualEstimatedCost.Min != expectedEstimatedCost.Min && actualEstimatedCost.Max != expectedEstimatedCost.Max { + t.Fatalf("expected estimated cost range to be %v, was %v", expectedEstimatedCost, actualEstimatedCost) + } + } + + var out ref.Val + var details *cel.EvalDetails + if evalArgs != nil { + out, details, err = program.Eval(evalArgs) + } else { + out, details, err = program.Eval(cel.NoVars()) + } + + if expectedRuntimeCost != 0 { + if details == nil { + t.Fatal("no EvalDetails available when runtime cost was expected") + } + if *details.ActualCost() != expectedRuntimeCost { + t.Fatalf("expected runtime cost to be %d, was %d", expectedRuntimeCost, *details.ActualCost()) + } + if expectedEstimatedCost.Min != 0 && expectedEstimatedCost.Max != 0 { + if *details.ActualCost() < expectedEstimatedCost.Min || *details.ActualCost() > expectedEstimatedCost.Max { + t.Fatalf("runtime cost %d outside of expected 
estimated cost range %v", *details.ActualCost(), expectedEstimatedCost) + } + } + } + return out, err + } + buildVariables := func(vars map[string]any) []cel.EnvOption { + opts := make([]cel.EnvOption, len(vars)) + i := 0 + for name, value := range vars { + t := cel.DynType + switch v := value.(type) { + case proto.Message: + t = cel.ObjectType(string(v.ProtoReflect().Descriptor().FullName())) + case types.Bool: + t = cel.BoolType + case types.Bytes: + t = cel.BytesType + case types.Double: + t = cel.DoubleType + case types.Duration: + t = cel.DurationType + case types.Int: + t = cel.IntType + case types.Null: + t = cel.NullType + case types.String: + t = cel.StringType + case types.Timestamp: + t = cel.TimestampType + case types.Uint: + t = cel.UintType + } + opts[i] = cel.Variable(name, t) + i++ + } + return opts + } + buildOpts := func(skipCompileCheck bool, locale string, variables []cel.EnvOption) []cel.EnvOption { + opts := []cel.EnvOption{ + Strings(StringsLocale(locale), StringsValidateFormatCalls(!skipCompileCheck), StringsVersion(3)), + cel.Container("ext"), + cel.Abbrevs("google.expr.proto3.test"), + cel.Types(&proto3pb.TestAllTypes{}), + NativeTypes( + reflect.TypeOf(&TestNestedType{}), + reflect.ValueOf(&TestAllTypes{}), + ), + } + opts = append(opts, cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals())) + opts = append(opts, variables...) + return opts + } + runCase := func(format, formatArgs, locale string, dynArgs map[string]any, skipCompileCheck bool, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { + env, err := cel.NewEnv(buildOpts(skipCompileCheck, locale, buildVariables(dynArgs))...) 
+ if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + expr := fmt.Sprintf("%q.format([%s])", format, formatArgs) + if len(dynArgs) == 0 { + return evalExpr(env, expr, cel.NoVars(), expectedRuntimeCost, expectedEstimatedCost, t) + } + return evalExpr(env, expr, dynArgs, expectedRuntimeCost, expectedEstimatedCost, t) + } + checkCase := func(output ref.Val, expectedOutput string, err error, expectedErr string, t *testing.T) { + if err != nil { + if expectedErr != "" { + if !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("expected %q as error message, got %q", expectedErr, err.Error()) + } + } else { + t.Fatalf("unexpected error: %s", err) + } + } else { + if output.Type() != types.StringType { + t.Fatalf("expected test expr to eval to string (got %s instead)", output.Type().TypeName()) + } else { + outputStr := output.Value().(string) + if outputStr != expectedOutput { + t.Errorf("expected %q as output, got %q", expectedOutput, outputStr) + } + } + } + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := runCase(tt.format, tt.formatArgs, tt.locale, tt.dynArgs, tt.skipCompileCheck, tt.expectedRuntimeCost, tt.expectedEstimatedCost, t) + checkCase(out, tt.expectedOutput, err, tt.err, t) + if tt.locale == "" { + // if the test has no locale specified, then that means it + // should have the same output regardless of locale + t.Run("no change on locale", func(t *testing.T) { + out, err := runCase(tt.format, tt.formatArgs, "da_DK", tt.dynArgs, tt.skipCompileCheck, tt.expectedRuntimeCost, tt.expectedEstimatedCost, t) + checkCase(out, tt.expectedOutput, err, tt.err, t) + }) + } + }) + } +} + +func TestStringFormatHeterogeneousLiterals(t *testing.T) { + tests := []struct { + expr string + out string + }{ + { + expr: `"list: %s".format([[[1, 2, [3.0, 4]]]])`, + out: `list: [[1, 2, [3.000000, 4]]]`, + }, + { + expr: `"list size: %d".format([[[1, 2, [3.0, 4]]].size()])`, + out: `list size: 1`, + }, + { + expr: `"list element: 
%s".format([[[1, 2, [3.0, 4]]][0]])`, + out: `list element: [1, 2, [3.000000, 4]]`, + }, + } + env, err := cel.NewEnv(Strings(StringsVersion(3)), cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals())) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("Eval() failed: %v", err) + } + if out.Value() != tc.out { + t.Errorf("Eval() got %v, wanted %v", out, tc.out) + } + }) + } +} + +func TestBadLocale(t *testing.T) { + _, err := cel.NewEnv(Strings(StringsLocale("bad-locale"), StringsVersion(3))) + if err != nil { + if err.Error() != "failed to parse locale: language: subtag \"locale\" is well-formed but unknown" { + t.Errorf("expected error messaged to be \"failed to parse locale: language: subtag \"locale\" is well-formed but unknown\", got %q", err) + } + } else { + t.Error("expected NewEnv to fail during locale parsing") + } +} + +func TestLiteralOutput(t *testing.T) { + tests := []struct { + name string + formatLiteral string + expectedType string + }{ + { + name: "map literal support", + formatLiteral: `{"key1": b"xyz", false: [11, 12, 13, timestamp("2019-10-12T07:20:50.52Z")], 42: {uint(64): 2.7}, "key5": type(int), "key2": duration("2h"), "key4": true, "key3": 2.71828, "null": null}`, + expectedType: `map`, + }, + { + name: "list literal support", + formatLiteral: `["abc", 3.14, uint(32), b"def", null, type(string), duration("7m"), [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")]`, + expectedType: `list`, + }, + } + for _, tt := range tests { + parseAndEval := func(expr string, t *testing.T) (ref.Val, error) { + env, err := cel.NewEnv(Strings(StringsVersion(3))) + if err != nil 
{ + t.Fatalf("cel.NewEnv(Strings(StringsVersion(3))) failed: %v", err) + } + parsedAst, issues := env.Parse(expr) + if issues.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", expr, issues.Err()) + } + checkedAst, issues := env.Check(parsedAst) + if issues.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", expr, issues.Err()) + } + program, err := env.Program(checkedAst) + if err != nil { + t.Fatal(err) + } + out, _, err := program.Eval(cel.NoVars()) + return out, err + } + t.Run(tt.name, func(t *testing.T) { + expr := fmt.Sprintf(`"%%s".format([%s])`, tt.formatLiteral) + literalVal, err := parseAndEval(expr, t) + if err != nil { + t.Fatalf("program.Eval failed: %v", err) + } + out, err := parseAndEval(literalVal.Value().(string), t) + if err != nil { + t.Fatalf("literal evaluation failed: %v", err) + } + if out.Type().TypeName() != tt.expectedType { + t.Errorf("expected literal to evaluate to type %s, got %s", tt.expectedType, out.Type().TypeName()) + } + equivalentVal, err := parseAndEval(literalVal.Value().(string)+" == "+tt.formatLiteral, t) + if err != nil { + t.Fatalf("equality evaluation failed: %v:", err) + } + if equivalentVal.Type().TypeName() != "bool" { + t.Errorf("expected equality expression to evaluation to type bool, got %s", equivalentVal.Type().TypeName()) + } + equivalent := equivalentVal.Value().(bool) + if !equivalent { + t.Errorf("%q (observed) and %q (expected) not considered equivalent", literalVal.Value().(string), tt.formatLiteral) + } + }) + } +} diff --git a/ext/formatting_v2.go b/ext/formatting_v2.go new file mode 100644 index 000000000..ca8efbc4e --- /dev/null +++ b/ext/formatting_v2.go @@ -0,0 +1,788 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "errors" + "fmt" + "math" + "sort" + "strconv" + "strings" + "time" + "unicode" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +type clauseImplV2 func(ref.Val) (string, error) + +type appendingFormatterV2 struct { + buf []byte +} + +type formattedMapEntryV2 struct { + key string + val string +} + +func (af *appendingFormatterV2) format(arg ref.Val) error { + switch arg.Type() { + case types.BoolType: + argBool, ok := arg.Value().(bool) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BoolType) + } + af.buf = strconv.AppendBool(af.buf, argBool) + return nil + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + af.buf = strconv.AppendInt(af.buf, argInt, 10) + return nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + af.buf = strconv.AppendUint(af.buf, argUint, 10) + return nil + case types.DoubleType: + argDbl, ok := arg.Value().(float64) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) + } + if math.IsNaN(argDbl) { + af.buf = append(af.buf, "NaN"...) 
+ return nil + } + if math.IsInf(argDbl, -1) { + af.buf = append(af.buf, "-Infinity"...) + return nil + } + if math.IsInf(argDbl, 1) { + af.buf = append(af.buf, "Infinity"...) + return nil + } + af.buf = strconv.AppendFloat(af.buf, argDbl, 'f', -1, 64) + return nil + case types.BytesType: + argBytes, ok := arg.Value().([]byte) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BytesType) + } + af.buf = append(af.buf, argBytes...) + return nil + case types.StringType: + argStr, ok := arg.Value().(string) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.StringType) + } + af.buf = append(af.buf, argStr...) + return nil + case types.DurationType: + argDur, ok := arg.Value().(time.Duration) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DurationType) + } + af.buf = strconv.AppendFloat(af.buf, argDur.Seconds(), 'f', -1, 64) + af.buf = append(af.buf, "s"...) + return nil + case types.TimestampType: + argTime, ok := arg.Value().(time.Time) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.TimestampType) + } + af.buf = argTime.UTC().AppendFormat(af.buf, time.RFC3339Nano) + return nil + case types.NullType: + af.buf = append(af.buf, "null"...) + return nil + case types.TypeType: + argType, ok := arg.Value().(string) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.TypeType) + } + af.buf = append(af.buf, argType...) + return nil + case types.ListType: + argList, ok := arg.(traits.Lister) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.ListType) + } + argIter := argList.Iterator() + af.buf = append(af.buf, "["...) + if argIter.HasNext() == types.True { + if err := af.format(argIter.Next()); err != nil { + return err + } + for argIter.HasNext() == types.True { + af.buf = append(af.buf, ", "...) 
+ if err := af.format(argIter.Next()); err != nil { + return err + } + } + } + af.buf = append(af.buf, "]"...) + return nil + case types.MapType: + argMap, ok := arg.(traits.Mapper) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.MapType) + } + argIter := argMap.Iterator() + ents := []formattedMapEntryV2{} + for argIter.HasNext() == types.True { + key := argIter.Next() + val, ok := argMap.Find(key) + if !ok { + return fmt.Errorf("key missing from map: '%s'", key) + } + keyStr, err := formatStringV2(key) + if err != nil { + return err + } + valStr, err := formatStringV2(val) + if err != nil { + return err + } + ents = append(ents, formattedMapEntryV2{keyStr, valStr}) + } + sort.SliceStable(ents, func(x, y int) bool { + return ents[x].key < ents[y].key + }) + af.buf = append(af.buf, "{"...) + for i, e := range ents { + if i > 0 { + af.buf = append(af.buf, ", "...) + } + af.buf = append(af.buf, e.key...) + af.buf = append(af.buf, ": "...) + af.buf = append(af.buf, e.val...) + } + af.buf = append(af.buf, "}"...) + return nil + default: + return stringFormatErrorV2(runtimeID, arg.Type().TypeName()) + } +} + +func formatStringV2(arg ref.Val) (string, error) { + var fmter appendingFormatterV2 + if err := fmter.format(arg); err != nil { + return "", err + } + return string(fmter.buf), nil +} + +type stringFormatterV2 struct{} + +// String implements formatStringInterpolatorV2.String. +func (c *stringFormatterV2) String(arg ref.Val) (string, error) { + return formatStringV2(arg) +} + +// Decimal implements formatStringInterpolatorV2.Decimal. 
+func (c *stringFormatterV2) Decimal(arg ref.Val) (string, error) { + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return strconv.FormatInt(argInt, 10), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return strconv.FormatUint(argUint, 10), nil + case types.DoubleType: + argDbl, ok := arg.Value().(float64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) + } + if math.IsNaN(argDbl) { + return "NaN", nil + } + if math.IsInf(argDbl, -1) { + return "-Infinity", nil + } + if math.IsInf(argDbl, 1) { + return "Infinity", nil + } + return strconv.FormatFloat(argDbl, 'f', -1, 64), nil + default: + return "", decimalFormatErrorV2(runtimeID, arg.Type().TypeName()) + } +} + +// Fixed implements formatStringInterpolatorV2.Fixed. 
+func (c *stringFormatterV2) Fixed(precision int) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + fmtStr := fmt.Sprintf("%%.%df", precision) + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return fmt.Sprintf(fmtStr, argInt), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return fmt.Sprintf(fmtStr, argUint), nil + case types.DoubleType: + argDbl, ok := arg.Value().(float64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) + } + if math.IsNaN(argDbl) { + return "NaN", nil + } + if math.IsInf(argDbl, -1) { + return "-Infinity", nil + } + if math.IsInf(argDbl, 1) { + return "Infinity", nil + } + return fmt.Sprintf(fmtStr, argDbl), nil + default: + return "", fixedPointFormatErrorV2(runtimeID, arg.Type().TypeName()) + } + } +} + +// Scientific implements formatStringInterpolatorV2.Scientific. 
+func (c *stringFormatterV2) Scientific(precision int) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + fmtStr := fmt.Sprintf("%%1.%de", precision) + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return fmt.Sprintf(fmtStr, argInt), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return fmt.Sprintf(fmtStr, argUint), nil + case types.DoubleType: + argDbl, ok := arg.Value().(float64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) + } + if math.IsNaN(argDbl) { + return "NaN", nil + } + if math.IsInf(argDbl, -1) { + return "-Infinity", nil + } + if math.IsInf(argDbl, 1) { + return "Infinity", nil + } + return fmt.Sprintf(fmtStr, argDbl), nil + default: + return "", scientificFormatErrorV2(runtimeID, arg.Type().TypeName()) + } + } +} + +// Binary implements formatStringInterpolatorV2.Binary. 
+func (c *stringFormatterV2) Binary(arg ref.Val) (string, error) { + switch arg.Type() { + case types.BoolType: + argBool, ok := arg.Value().(bool) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BoolType) + } + if argBool { + return "1", nil + } + return "0", nil + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return strconv.FormatInt(argInt, 2), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return strconv.FormatUint(argUint, 2), nil + default: + return "", binaryFormatErrorV2(runtimeID, arg.Type().TypeName()) + } +} + +// Hex implements formatStringInterpolatorV2.Hex. +func (c *stringFormatterV2) Hex(useUpper bool) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + var fmtStr string + if useUpper { + fmtStr = "%X" + } else { + fmtStr = "%x" + } + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return fmt.Sprintf(fmtStr, argInt), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return fmt.Sprintf(fmtStr, argUint), nil + case types.StringType: + argStr, ok := arg.Value().(string) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.StringType) + } + return fmt.Sprintf(fmtStr, argStr), nil + case types.BytesType: + argBytes, ok := arg.Value().([]byte) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BytesType) + } + return fmt.Sprintf(fmtStr, argBytes), nil + default: + return "", 
hexFormatErrorV2(runtimeID, arg.Type().TypeName()) + } + } +} + +// Octal implements formatStringInterpolatorV2.Octal. +func (c *stringFormatterV2) Octal(arg ref.Val) (string, error) { + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return strconv.FormatInt(argInt, 8), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return strconv.FormatUint(argUint, 8), nil + default: + return "", octalFormatErrorV2(runtimeID, arg.Type().TypeName()) + } +} + +// stringFormatValidatorV2 implements the cel.ASTValidator interface allowing for static validation +// of string.format calls. +type stringFormatValidatorV2 struct{} + +// Name returns the name of the validator. +func (stringFormatValidatorV2) Name() string { + return "cel.validator.string_format" +} + +// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip +// during homogeneous aggregate literal type-checks. +func (stringFormatValidatorV2) Configure(config cel.MutableValidatorConfig) error { + functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string) + functions = append(functions, "format") + return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions) +} + +// Validate parses all literal format strings and type checks the format clause against the argument +// at the corresponding ordinal within the list literal argument to the function, if one is specified. 
+func (stringFormatValidatorV2) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) { + root := ast.NavigateAST(a) + formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a)) + for _, e := range formatCallExprs { + call := e.AsCall() + formatStr := call.Target().AsLiteral().Value().(string) + args := call.Args()[0].AsList().Elements() + formatCheck := &stringFormatCheckerV2{ + args: args, + ast: a, + } + // use a placeholder locale, since locale doesn't affect syntax + _, err := parseFormatStringV2(formatStr, formatCheck, formatCheck) + if err != nil { + iss.ReportErrorAtID(getErrorExprID(e.ID(), err), "%v", err) + continue + } + seenArgs := formatCheck.argsRequested + if len(args) > seenArgs { + iss.ReportErrorAtID(e.ID(), + "too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(args)) + } + } +} + +// stringFormatCheckerV2 implements the formatStringInterpolater interface +type stringFormatCheckerV2 struct { + args []ast.Expr + argsRequested int + currArgIndex int64 + ast *ast.AST +} + +// String implements formatStringInterpolatorV2.String. +func (c *stringFormatCheckerV2) String(arg ref.Val) (string, error) { + formatArg := c.args[c.currArgIndex] + valid, badID := c.verifyString(formatArg) + if !valid { + return "", stringFormatErrorV2(badID, c.typeOf(badID).TypeName()) + } + return "", nil +} + +// Decimal implements formatStringInterpolatorV2.Decimal. +func (c *stringFormatCheckerV2) Decimal(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) + if !valid { + return "", decimalFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +// Fixed implements formatStringInterpolatorV2.Fixed. 
+func (c *stringFormatCheckerV2) Fixed(precision int) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) + if !valid { + return "", fixedPointFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +// Scientific implements formatStringInterpolatorV2.Scientific. +func (c *stringFormatCheckerV2) Scientific(precision int) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) + if !valid { + return "", scientificFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +// Binary implements formatStringInterpolatorV2.Binary. +func (c *stringFormatCheckerV2) Binary(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.BoolType, types.IntType, types.UintType) + if !valid { + return "", binaryFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +// Hex implements formatStringInterpolatorV2.Hex. +func (c *stringFormatCheckerV2) Hex(useUpper bool) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.StringType, types.BytesType) + if !valid { + return "", hexFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +// Octal implements formatStringInterpolatorV2.Octal. +func (c *stringFormatCheckerV2) Octal(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) + if !valid { + return "", octalFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +// Arg implements formatListArgs.Arg. 
+func (c *stringFormatCheckerV2) Arg(index int64) (ref.Val, error) { + c.argsRequested++ + c.currArgIndex = index + // return a dummy value - this is immediately passed to back to us + // through one of the FormatCallback functions, so anything will do + return types.Int(0), nil +} + +// Size implements formatListArgs.Size. +func (c *stringFormatCheckerV2) Size() int64 { + return int64(len(c.args)) +} + +func (c *stringFormatCheckerV2) typeOf(id int64) *cel.Type { + return c.ast.GetType(id) +} + +func (c *stringFormatCheckerV2) verifyTypeOneOf(id int64, validTypes ...*cel.Type) bool { + t := c.typeOf(id) + if t == cel.DynType { + return true + } + for _, vt := range validTypes { + // Only check runtime type compatibility without delving deeper into parameterized types + if t.Kind() == vt.Kind() { + return true + } + } + return false +} + +func (c *stringFormatCheckerV2) verifyString(sub ast.Expr) (bool, int64) { + paramA := cel.TypeParamType("A") + paramB := cel.TypeParamType("B") + subVerified := c.verifyTypeOneOf(sub.ID(), + cel.ListType(paramA), cel.MapType(paramA, paramB), + cel.IntType, cel.UintType, cel.DoubleType, cel.BoolType, cel.StringType, + cel.TimestampType, cel.BytesType, cel.DurationType, cel.TypeType, cel.NullType) + if !subVerified { + return false, sub.ID() + } + switch sub.Kind() { + case ast.ListKind: + for _, e := range sub.AsList().Elements() { + // recursively verify if we're dealing with a list/map + verified, id := c.verifyString(e) + if !verified { + return false, id + } + } + return true, sub.ID() + case ast.MapKind: + for _, e := range sub.AsMap().Entries() { + // recursively verify if we're dealing with a list/map + entry := e.AsMapEntry() + verified, id := c.verifyString(entry.Key()) + if !verified { + return false, id + } + verified, id = c.verifyString(entry.Value()) + if !verified { + return false, id + } + } + return true, sub.ID() + default: + return true, sub.ID() + } +} + +// helper routines for reporting common errors during 
string formatting static validation and +// runtime execution. + +func binaryFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "only ints, uints, and bools can be formatted as binary, was given %s", badType) +} + +func decimalFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "decimal clause can only be used on ints, uints, and doubles, was given %s", badType) +} + +func fixedPointFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "fixed-point clause can only be used on ints, uints, and doubles, was given %s", badType) +} + +func hexFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "only ints, uints, bytes, and strings can be formatted as hex, was given %s", badType) +} + +func octalFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "octal clause can only be used on ints and uints, was given %s", badType) +} + +func scientificFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "scientific clause can only be used on ints, uints, and doubles, was given %s", badType) +} + +func stringFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", badType) +} + +// formatStringInterpolatorV2 is an interface that allows user-defined behavior +// for formatting clause implementations, as well as argument retrieval. +// Each function is expected to support the appropriate types as laid out in +// the string.format documentation, and to return an error if given an inappropriate type. +type formatStringInterpolatorV2 interface { + // String takes a ref.Val and a string representing the current locale identifier + // and returns the Val formatted as a string, or an error if one occurred. 
+	String(ref.Val) (string, error)
+
+	// Decimal takes a ref.Val and returns the Val formatted as a decimal
+	// integer, or an error if one occurred.
+	Decimal(ref.Val) (string, error)
+
+	// Fixed takes an int representing precision and returns a function
+	// operating in a similar manner to String and Decimal, taking a
+	// ref.Val and returning the appropriate string. A closure is returned
+	// so precision can be set without needing an additional function call/configuration.
+	Fixed(int) func(ref.Val) (string, error)
+
+	// Scientific functions identically to Fixed, except the string returned from the closure
+	// is expected to be in scientific notation.
+	Scientific(int) func(ref.Val) (string, error)
+
+	// Binary takes a ref.Val and returns the Val formatted as a binary
+	// integer, or an error if one occurred.
+	Binary(ref.Val) (string, error)
+
+	// Hex takes a boolean that, if true, indicates the hex string output by the returned
+	// closure should use uppercase letters for A-F.
+	Hex(bool) func(ref.Val) (string, error)
+
+	// Octal takes a ref.Val and returns the Val formatted in octal,
+	// or an error if one occurred.
+	Octal(ref.Val) (string, error)
+}
+
+// parseFormatStringV2 formats a string according to the string.format syntax, taking the clause implementations
+// from the provided FormatCallback and the args from the given FormatList.
+func parseFormatStringV2(formatStr string, callback formatStringInterpolatorV2, list formatListArgs) (string, error) { + i := 0 + argIndex := 0 + var builtStr strings.Builder + for i < len(formatStr) { + if formatStr[i] == '%' { + if i+1 < len(formatStr) && formatStr[i+1] == '%' { + err := builtStr.WriteByte('%') + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i += 2 + continue + } else { + argAny, err := list.Arg(int64(argIndex)) + if err != nil { + return "", err + } + if i+1 >= len(formatStr) { + return "", errors.New("unexpected end of string") + } + if int64(argIndex) >= list.Size() { + return "", fmt.Errorf("index %d out of range", argIndex) + } + numRead, val, refErr := parseAndFormatClauseV2(formatStr[i:], argAny, callback, list) + if refErr != nil { + return "", refErr + } + _, err = builtStr.WriteString(val) + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i += numRead + argIndex++ + } + } else { + err := builtStr.WriteByte(formatStr[i]) + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i++ + } + } + return builtStr.String(), nil +} + +// parseAndFormatClause parses the format clause at the start of the given string with val, and returns +// how many characters were consumed and the substituted string form of val, or an error if one occurred. 
+func parseAndFormatClauseV2(formatStr string, val ref.Val, callback formatStringInterpolatorV2, list formatListArgs) (int, string, error) { + i := 1 + read, formatter, err := parseFormattingClauseV2(formatStr[i:], callback) + i += read + if err != nil { + return -1, "", newParseFormatError("could not parse formatting clause", err) + } + + valStr, err := formatter(val) + if err != nil { + return -1, "", newParseFormatError("error during formatting", err) + } + return i, valStr, nil +} + +func parseFormattingClauseV2(formatStr string, callback formatStringInterpolatorV2) (int, clauseImplV2, error) { + i := 0 + read, precision, err := parsePrecisionV2(formatStr[i:]) + i += read + if err != nil { + return -1, nil, fmt.Errorf("error while parsing precision: %w", err) + } + r := rune(formatStr[i]) + i++ + switch r { + case 's': + return i, callback.String, nil + case 'd': + return i, callback.Decimal, nil + case 'f': + return i, callback.Fixed(precision), nil + case 'e': + return i, callback.Scientific(precision), nil + case 'b': + return i, callback.Binary, nil + case 'x', 'X': + return i, callback.Hex(unicode.IsUpper(r)), nil + case 'o': + return i, callback.Octal, nil + default: + return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r) + } +} + +func parsePrecisionV2(formatStr string) (int, int, error) { + i := 0 + if formatStr[i] != '.' 
{ + return i, defaultPrecision, nil + } + i++ + var buffer strings.Builder + for { + if i >= len(formatStr) { + return -1, -1, errors.New("could not find end of precision specifier") + } + if !isASCIIDigit(rune(formatStr[i])) { + break + } + buffer.WriteByte(formatStr[i]) + i++ + } + precision, err := strconv.Atoi(buffer.String()) + if err != nil { + return -1, -1, fmt.Errorf("error while converting precision to integer: %w", err) + } + if precision < 0 { + return -1, -1, fmt.Errorf("negative precision: %d", precision) + } + return i, precision, nil +} diff --git a/ext/formatting_v2_test.go b/ext/formatting_v2_test.go new file mode 100644 index 000000000..d574e62de --- /dev/null +++ b/ext/formatting_v2_test.go @@ -0,0 +1,1015 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ext + +import ( + "fmt" + "math" + "reflect" + "strings" + "testing" + "time" + + "google.golang.org/protobuf/proto" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + + proto3pb "github.com/google/cel-go/test/proto3pb" +) + +func TestStringsWithExtensionV2(t *testing.T) { + env, err := cel.NewEnv(Strings()) + if err != nil { + t.Fatalf("cel.NewEnv(Strings()) failed: %v", err) + } + _, err = env.Extend(Strings()) + if err != nil { + t.Fatalf("env.Extend(Strings()) failed: %v", err) + } +} + +func TestStringFormatV2(t *testing.T) { + tests := []struct { + name string + format string + dynArgs map[string]any + formatArgs string + err string + expectedOutput string + expectedRuntimeCost uint64 + expectedEstimatedCost checker.CostEstimate + skipCompileCheck bool + }{ + { + name: "no-op", + format: "no substitution", + expectedOutput: "no substitution", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + + { + name: "mid-string substitution", + format: "str is %s and some more", + formatArgs: `"filler"`, + expectedOutput: "str is filler and some more", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "percent escaping", + format: "%% and also %%", + expectedOutput: "% and also %", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "substution inside escaped percent signs", + format: "%%%s%%", + formatArgs: `"text"`, + expectedOutput: "%text%", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "substitution with one escaped percent sign on the right", + format: "%s%%", + formatArgs: `"percent on the right"`, + expectedOutput: "percent on the right%", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, 
Max: 11}, + }, + { + name: "substitution with one escaped percent sign on the left", + format: "%%%s", + formatArgs: `"percent on the left"`, + expectedOutput: "%percent on the left", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "multiple substitutions", + format: "%d %d %d, %s %s %s, %d %d %d, %s %s %s", + formatArgs: `1, 2, 3, "A", "B", "C", 4, 5, 6, "D", "E", "F"`, + expectedOutput: "1 2 3, A B C, 4 5 6, D E F", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "percent sign escape sequence support", + format: "\u0025\u0025escaped \u0025s\u0025\u0025", + formatArgs: `"percent"`, + expectedOutput: "%escaped percent%", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "fixed point formatting clause", + format: "%.3f", + formatArgs: "1.2345", + expectedOutput: "1.234", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "binary formatting clause", + format: "this is 5 in binary: %b", + formatArgs: "5", + expectedOutput: "this is 5 in binary: 101", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "negative binary formatting clause", + format: "this is -5 in binary: %b", + formatArgs: "-5", + expectedOutput: "this is -5 in binary: -101", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "uint support for binary formatting", + format: "unsigned 64 in binary: %b", + formatArgs: "uint(64)", + expectedOutput: "unsigned 64 in binary: 1000000", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "bool support for binary formatting", + format: "bit set from bool: %b", + formatArgs: "true", + expectedOutput: "bit set from bool: 1", + expectedRuntimeCost: 13, + 
expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "octal formatting clause", + format: "%o", + formatArgs: "11", + expectedOutput: "13", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "negative octal formatting clause", + format: "%o", + formatArgs: "-11", + expectedOutput: "-13", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "uint support for octal formatting clause", + format: "this is an unsigned octal: %o", + formatArgs: "uint(65535)", + expectedOutput: "this is an unsigned octal: 177777", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "lowercase hexadecimal formatting clause", + format: "%x is 30 in hexadecimal", + formatArgs: "30", + expectedOutput: "1e is 30 in hexadecimal", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "uppercase hexadecimal formatting clause", + format: "%X is 20 in hexadecimal", + formatArgs: "30", + expectedOutput: "1E is 20 in hexadecimal", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "negative hexadecimal formatting clause", + format: "%x is -30 in hexadecimal", + formatArgs: "-30", + expectedOutput: "-1e is -30 in hexadecimal", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "unsigned support for hexadecimal formatting clause", + format: "%X is 6000 in hexadecimal", + formatArgs: "uint(6000)", + expectedOutput: "1770 is 6000 in hexadecimal", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "string support with hexadecimal formatting clause", + format: "%x", + formatArgs: `"Hello world!"`, + expectedOutput: "48656c6c6f20776f726c6421", + expectedRuntimeCost: 11, + 
expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "string support with uppercase hexadecimal formatting clause", + format: "%X", + formatArgs: `"Hello world!"`, + expectedOutput: "48656C6C6F20776F726C6421", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "byte support with hexadecimal formatting clause", + format: "%x", + formatArgs: `b"byte string"`, + expectedOutput: "6279746520737472696e67", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "byte support with hexadecimal formatting clause leading zero", + format: "%x", + formatArgs: `b"\x00\x00byte string\x00"`, + expectedOutput: "00006279746520737472696e6700", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "byte support with uppercase hexadecimal formatting clause", + format: "%X", + formatArgs: `b"byte string"`, + expectedOutput: "6279746520737472696E67", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "scientific notation formatting clause", + format: "%.6e", + formatArgs: "1052.032911275", + expectedOutput: "1.052033e+03", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "default precision for fixed-point clause", + format: "%f", + formatArgs: "2.71828", + expectedOutput: "2.718280", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "default precision for scientific notation", + format: "%e", + formatArgs: "2.71828", + expectedOutput: "2.718280e+00", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "default precision for string", + format: "%s", + formatArgs: "2.71", + expectedOutput: "2.71", + expectedRuntimeCost: 11, + expectedEstimatedCost: 
checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "default list precision for string", + format: "%s", + formatArgs: "[2.71]", + expectedOutput: "[2.71]", + expectedRuntimeCost: 21, + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + }, + { + name: "default format for string", + format: "%s", + formatArgs: "0.000000002", + expectedOutput: "0.000000002", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "default list scientific notation for string", + format: "%s", + formatArgs: "[0.000000002]", + expectedOutput: "[0.000000002]", + expectedRuntimeCost: 21, + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + }, + { + name: "NaN support for fixed-point", + format: "%f", + formatArgs: `double("NaN")`, + expectedOutput: "NaN", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "positive infinity support for fixed-point", + format: "%f", + formatArgs: `double("Infinity")`, + expectedOutput: "Infinity", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "negative infinity support for fixed-point", + format: "%f", + formatArgs: `double("-Infinity")`, + expectedOutput: "-Infinity", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "NaN support for string", + format: "%s", + formatArgs: `double("NaN")`, + expectedOutput: "NaN", + }, + { + name: "positive infinity support for string", + format: "%s", + formatArgs: `double("Infinity")`, + expectedOutput: "Infinity", + }, + { + name: "negative infinity support for string", + format: "%s", + formatArgs: `double("-Infinity")`, + expectedOutput: "-Infinity", + }, + { + name: "infinity list support for string", + format: "%s", + formatArgs: `[double("NaN"),double("+Infinity"), double("-Infinity")]`, + expectedOutput: `[NaN, Infinity, -Infinity]`, + }, + { + name: 
"uint support for decimal clause", + format: "%d", + formatArgs: "uint(64)", + expectedOutput: "64", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "null support for string", + format: "null: %s", + formatArgs: "null", + expectedOutput: "null: null", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "int support for string", + format: "%s", + formatArgs: `999999999999`, + expectedOutput: "999999999999", + expectedRuntimeCost: 11, + expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, + }, + { + name: "bytes support for string", + format: "some bytes: %s", + formatArgs: `b"xyz"`, + expectedOutput: "some bytes: xyz", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "type() support for string", + format: "type is %s", + formatArgs: `type("test string")`, + expectedOutput: "type is string", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "timestamp support for string", + format: "%s", + formatArgs: `timestamp("2023-02-03T23:31:20+00:00")`, + expectedOutput: "2023-02-03T23:31:20Z", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "duration support for string", + format: "%s", + formatArgs: `duration("1h45m47s")`, + expectedOutput: "6347s", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "small duration support for string", + format: "%s", + formatArgs: `duration("2ns")`, + expectedOutput: "0.000000002s", + expectedRuntimeCost: 12, + expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, + }, + { + name: "list support for string", + format: "%s", + formatArgs: `["abc", 3.14, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")]`, + expectedOutput: `[abc, 3.14, null, [9, 8, 7, 6], 
2023-02-03T23:31:20Z]`, + expectedRuntimeCost: 32, + expectedEstimatedCost: checker.CostEstimate{Min: 32, Max: 32}, + }, + { + name: "map support for string", + format: "%s", + formatArgs: `{"key1": b"xyz", "key5": null, "key2": duration("2h"), "key4": true, "key3": 2.71828}`, + expectedOutput: `{key1: xyz, key2: 7200s, key3: 2.71828, key4: true, key5: null}`, + expectedRuntimeCost: 42, + expectedEstimatedCost: checker.CostEstimate{Min: 42, Max: 42}, + }, + { + name: "map support (all key types)", + format: "map with multiple key types: %s", + formatArgs: `{1: "value1", uint(2): "value2", true: double("NaN")}`, + expectedOutput: `map with multiple key types: {1: value1, 2: value2, true: NaN}`, + expectedRuntimeCost: 46, + expectedEstimatedCost: checker.CostEstimate{Min: 46, Max: 46}, + }, + { + name: "boolean support for %s", + format: "true bool: %s, false bool: %s", + formatArgs: `true, false`, + expectedOutput: "true bool: true, false bool: false", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for string formatting clause", + format: "dynamic string: %s", + formatArgs: `dynStr`, + dynArgs: map[string]any{ + "dynStr": "a string", + }, + expectedOutput: "dynamic string: a string", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for numbers with string formatting clause", + format: "dynIntStr: %s dynDoubleStr: %s", + formatArgs: `dynIntStr, dynDoubleStr`, + dynArgs: map[string]any{ + "dynIntStr": 32, + "dynDoubleStr": 56.8, + }, + expectedOutput: "dynIntStr: 32 dynDoubleStr: 56.8", + expectedRuntimeCost: 15, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, + }, + { + name: "dyntype support for integer formatting clause", + format: "dynamic int: %d", + formatArgs: `dynInt`, + dynArgs: map[string]any{ + "dynInt": 128, + }, + expectedOutput: "dynamic int: 128", + expectedRuntimeCost: 13, + 
expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for integer formatting clause (unsigned)", + format: "dynamic unsigned int: %d", + formatArgs: `dynUnsignedInt`, + dynArgs: map[string]any{ + "dynUnsignedInt": uint64(256), + }, + expectedOutput: "dynamic unsigned int: 256", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "dyntype support for hex formatting clause", + format: "dynamic hex int: %x", + formatArgs: `dynHexInt`, + dynArgs: map[string]any{ + "dynHexInt": 22, + }, + expectedOutput: "dynamic hex int: 16", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for hex formatting clause (uppercase)", + format: "dynamic hex int: %X (uppercase)", + formatArgs: `dynHexInt`, + dynArgs: map[string]any{ + "dynHexInt": 26, + }, + expectedOutput: "dynamic hex int: 1A (uppercase)", + expectedRuntimeCost: 15, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, + }, + { + name: "dyntype support for unsigned hex formatting clause", + format: "dynamic hex int: %x (unsigned)", + formatArgs: `dynUnsignedHexInt`, + dynArgs: map[string]any{ + "dynUnsignedHexInt": uint(500), + }, + expectedOutput: "dynamic hex int: 1f4 (unsigned)", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "dyntype support for fixed-point formatting clause", + format: "dynamic double: %.3f", + formatArgs: `dynDouble`, + dynArgs: map[string]any{ + "dynDouble": 4.5, + }, + expectedOutput: "dynamic double: 4.500", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for scientific notation", + format: "(dyntype) e: %e", + formatArgs: "dynE", + dynArgs: map[string]any{ + "dynE": 2.71828, + }, + expectedOutput: "(dyntype) e: 2.718280e+00", + expectedRuntimeCost: 13, + 
expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype NaN/infinity support for fixed-point", + format: "NaN: %f, infinity: %f", + formatArgs: `dynNaN, dynInf`, + dynArgs: map[string]any{ + "dynNaN": math.NaN(), + "dynInf": math.Inf(1), + }, + expectedOutput: "NaN: NaN, infinity: Infinity", + expectedRuntimeCost: 15, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, + }, + { + name: "dyntype support for timestamp", + format: "dyntype timestamp: %s", + formatArgs: `dynTime`, + dynArgs: map[string]any{ + "dynTime": time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), + }, + expectedOutput: "dyntype timestamp: 2009-11-10T23:00:00Z", + expectedRuntimeCost: 14, + expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, + }, + { + name: "dyntype support for duration", + format: "dyntype duration: %s", + formatArgs: `dynDuration`, + dynArgs: map[string]any{ + "dynDuration": mustParseDuration("2h25m47s"), + }, + expectedOutput: "dyntype duration: 8747s", + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for lists", + format: "dyntype list: %s", + formatArgs: `dynList`, + dynArgs: map[string]any{ + "dynList": []any{6, 4.2, "a string"}, + }, + expectedOutput: `dyntype list: [6, 4.2, a string]`, + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "dyntype support for maps", + format: "dyntype map: %s", + formatArgs: `dynMap`, + dynArgs: map[string]any{ + "dynMap": map[any]any{ + "strKey": "x", + true: 42, + int64(6): mustParseDuration("7m2s"), + }, + }, + expectedOutput: `dyntype map: {6: 422s, strKey: x, true: 42}`, + expectedRuntimeCost: 13, + expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, + }, + { + name: "message field support", + format: "message field msg.single_int32: %d, msg.single_double: %.1f", + formatArgs: `msg.single_int32, msg.single_double`, + dynArgs: 
map[string]any{ + "msg": &proto3pb.TestAllTypes{ + SingleInt32: 2, + SingleDouble: 1.0, + }, + }, + expectedOutput: `message field msg.single_int32: 2, msg.single_double: 1.0`, + }, + { + name: "unrecognized formatting clause", + format: "%a", + formatArgs: "1", + skipCompileCheck: true, + err: "could not parse formatting clause: unrecognized formatting clause \"a\"", + }, + { + name: "out of bounds arg index", + format: "%d %d %d", + formatArgs: "0, 1", + skipCompileCheck: true, + err: "index 2 out of range", + }, + { + name: "string substitution is not allowed with binary clause", + format: "string is %b", + formatArgs: `"abc"`, + skipCompileCheck: true, + err: "error during formatting: only ints, uints, and bools can be formatted as binary, was given string", + }, + { + name: "duration substitution not allowed with decimal clause", + format: "%d", + formatArgs: `duration("30m2s")`, + skipCompileCheck: true, + err: "error during formatting: decimal clause can only be used on ints, uints, and doubles, was given google.protobuf.Duration", + }, + { + name: "string substitution not allowed with octal clause", + format: "octal: %o", + formatArgs: `"a string"`, + skipCompileCheck: true, + err: "error during formatting: octal clause can only be used on ints and uints, was given string", + }, + { + name: "double substitution not allowed with hex clause", + format: "double is %x", + formatArgs: "0.5", + skipCompileCheck: true, + err: "error during formatting: only ints, uints, bytes, and strings can be formatted as hex, was given double", + }, + { + name: "uppercase not allowed for scientific clause", + format: "double is %E", + formatArgs: "0.5", + skipCompileCheck: true, + err: `could not parse formatting clause: unrecognized formatting clause "E"`, + }, + { + name: "object not allowed", + format: "object is %s", + formatArgs: `ext.TestAllTypes{PbVal: test.TestAllTypes{}}`, + skipCompileCheck: true, + err: "error during formatting: string clause can only be used on 
strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", + }, + { + name: "object inside list", + format: "%s", + formatArgs: "[1, 2, ext.TestAllTypes{PbVal: test.TestAllTypes{}}]", + skipCompileCheck: true, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", + }, + { + name: "object inside map", + format: "%s", + formatArgs: `{1: "a", 2: ext.TestAllTypes{}}`, + skipCompileCheck: true, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", + }, + { + name: "null not allowed for %d", + format: "null: %d", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: decimal clause can only be used on ints, uints, and doubles, was given null_type", + }, + { + name: "null not allowed for %e", + format: "null: %e", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: scientific clause can only be used on ints, uints, and doubles, was given null_type", + }, + { + name: "null not allowed for %f", + format: "null: %f", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: fixed-point clause can only be used on ints, uints, and doubles, was given null_type", + }, + { + name: "null not allowed for %x", + format: "null: %x", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: only ints, uints, bytes, and strings can be formatted as hex, was given null_type", + }, + { + name: "null not allowed for %X", + format: "null: %X", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: only ints, uints, bytes, and strings can be formatted as hex, was given null_type", + }, + { + name: "null not allowed for %b", + format: "null: %b", + formatArgs: 
"null", + skipCompileCheck: true, + err: "error during formatting: only ints, uints, and bools can be formatted as binary, was given null_type", + }, + { + name: "null not allowed for %o", + format: "null: %o", + formatArgs: "null", + skipCompileCheck: true, + err: "error during formatting: octal clause can only be used on ints and uints, was given null_type", + }, + { + name: "compile-time cardinality check (too few for string)", + format: "%s %s", + formatArgs: `"abc"`, + err: "index 1 out of range", + }, + { + name: "compile-time cardinality check (too many for string)", + format: "%s %s", + formatArgs: `"abc", "def", "ghi"`, + err: "too many arguments supplied to string.format (expected 2, got 3)", + }, + { + name: "compile-time syntax check (unexpected end of string)", + format: "filler %", + formatArgs: "", + err: "unexpected end of string", + }, + { + name: "compile-time syntax check (unrecognized formatting clause)", + format: "%j", + // pass args here, otherwise the cardinality check will fail first + formatArgs: "123", + err: `could not parse formatting clause: unrecognized formatting clause "j"`, + }, + { + name: "compile-time %s check", + format: "object is %s", + formatArgs: `ext.TestAllTypes{PbVal: test.TestAllTypes{}}`, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps", + }, + { + name: "compile-time check for objects inside list literal", + format: "list is %s", + formatArgs: `[1, 2, ext.TestAllTypes{PbVal: test.TestAllTypes{}}]`, + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps", + }, + { + name: "compile-time %d check", + format: "int is %d", + formatArgs: "null", + err: "error during formatting: decimal clause can only be used on ints, uints, and doubles, was given null_type", + }, + { + name: "compile-time %f check", + format: "double is %f", + 
formatArgs: "true", + err: "error during formatting: fixed-point clause can only be used on ints, uints, and doubles, was given bool", + }, + { + name: "compile-time precision syntax check", + format: "double is %.34", + formatArgs: "5.0", + err: "could not parse formatting clause: error while parsing precision: could not find end of precision specifier", + }, + { + name: "compile-time %e check", + format: "double is %e", + formatArgs: "true", + err: "error during formatting: scientific clause can only be used on ints, uints, and doubles, was given bool", + }, + { + name: "compile-time %b check", + format: "string is %b", + formatArgs: `"a string"`, + err: "error during formatting: only ints, uints, and bools can be formatted as binary, was given string", + }, + { + name: "compile-time %x check", + format: "%x is a double", + formatArgs: "2.5", + err: "error during formatting: only ints, uints, bytes, and strings can be formatted as hex, was given double", + }, + { + name: "compile-time %X check", + format: "%X is a double", + formatArgs: "2.5", + err: "error during formatting: only ints, uints, bytes, and strings can be formatted as hex, was given double", + }, + { + name: "compile-time %o check", + format: "an octal: %o", + formatArgs: "3.14", + err: "octal clause can only be used on ints and uints, was given double", + }, + } + evalExpr := func(env *cel.Env, expr string, evalArgs any, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { + t.Logf("evaluating expr: %s", expr) + parsedAst, issues := env.Parse(expr) + if issues.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", expr, issues.Err()) + } + checkedAst, issues := env.Check(parsedAst) + if issues.Err() != nil { + return nil, issues.Err() + } + evalOpts := make([]cel.ProgramOption, 0) + costTracker := &noopCostEstimator{} + if expectedRuntimeCost != 0 { + evalOpts = append(evalOpts, cel.CostTracking(costTracker)) + } + program, err := 
env.Program(checkedAst, evalOpts...) + if err != nil { + return nil, err + } + + actualEstimatedCost, err := env.EstimateCost(checkedAst, costTracker) + if err != nil { + t.Fatal(err) + } + if expectedEstimatedCost.Min != 0 && expectedEstimatedCost.Max != 0 { + if actualEstimatedCost.Min != expectedEstimatedCost.Min && actualEstimatedCost.Max != expectedEstimatedCost.Max { + t.Fatalf("expected estimated cost range to be %v, was %v", expectedEstimatedCost, actualEstimatedCost) + } + } + + var out ref.Val + var details *cel.EvalDetails + if evalArgs != nil { + out, details, err = program.Eval(evalArgs) + } else { + out, details, err = program.Eval(cel.NoVars()) + } + + if expectedRuntimeCost != 0 { + if details == nil { + t.Fatal("no EvalDetails available when runtime cost was expected") + } + if *details.ActualCost() != expectedRuntimeCost { + t.Fatalf("expected runtime cost to be %d, was %d", expectedRuntimeCost, *details.ActualCost()) + } + if expectedEstimatedCost.Min != 0 && expectedEstimatedCost.Max != 0 { + if *details.ActualCost() < expectedEstimatedCost.Min || *details.ActualCost() > expectedEstimatedCost.Max { + t.Fatalf("runtime cost %d outside of expected estimated cost range %v", *details.ActualCost(), expectedEstimatedCost) + } + } + } + return out, err + } + buildVariables := func(vars map[string]any) []cel.EnvOption { + opts := make([]cel.EnvOption, len(vars)) + i := 0 + for name, value := range vars { + t := cel.DynType + switch v := value.(type) { + case proto.Message: + t = cel.ObjectType(string(v.ProtoReflect().Descriptor().FullName())) + case types.Bool: + t = cel.BoolType + case types.Bytes: + t = cel.BytesType + case types.Double: + t = cel.DoubleType + case types.Duration: + t = cel.DurationType + case types.Int: + t = cel.IntType + case types.Null: + t = cel.NullType + case types.String: + t = cel.StringType + case types.Timestamp: + t = cel.TimestampType + case types.Uint: + t = cel.UintType + } + opts[i] = cel.Variable(name, t) + i++ + } + 
return opts + } + buildOpts := func(skipCompileCheck bool, variables []cel.EnvOption) []cel.EnvOption { + opts := []cel.EnvOption{ + Strings(StringsValidateFormatCalls(!skipCompileCheck)), + cel.Container("ext"), + cel.Abbrevs("google.expr.proto3.test"), + cel.Types(&proto3pb.TestAllTypes{}), + NativeTypes( + reflect.TypeOf(&TestNestedType{}), + reflect.ValueOf(&TestAllTypes{}), + ), + } + opts = append(opts, cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals())) + opts = append(opts, variables...) + return opts + } + runCase := func(format, formatArgs string, dynArgs map[string]any, skipCompileCheck bool, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { + env, err := cel.NewEnv(buildOpts(skipCompileCheck, buildVariables(dynArgs))...) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + expr := fmt.Sprintf("%q.format([%s])", format, formatArgs) + if len(dynArgs) == 0 { + return evalExpr(env, expr, cel.NoVars(), expectedRuntimeCost, expectedEstimatedCost, t) + } + return evalExpr(env, expr, dynArgs, expectedRuntimeCost, expectedEstimatedCost, t) + } + checkCase := func(output ref.Val, expectedOutput string, err error, expectedErr string, t *testing.T) { + if err != nil { + if expectedErr != "" { + if !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("expected %q as error message, got %q", expectedErr, err.Error()) + } + } else { + t.Fatalf("unexpected error: %s", err) + } + } else { + if output.Type() != types.StringType { + t.Fatalf("expected test expr to eval to string (got %s instead)", output.Type().TypeName()) + } else { + outputStr := output.Value().(string) + if outputStr != expectedOutput { + t.Errorf("expected %q as output, got %q", expectedOutput, outputStr) + } + } + } + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := runCase(tt.format, tt.formatArgs, tt.dynArgs, tt.skipCompileCheck, tt.expectedRuntimeCost, tt.expectedEstimatedCost, t) + 
checkCase(out, tt.expectedOutput, err, tt.err, t) + }) + } +} + +func TestStringFormatHeterogeneousLiteralsV2(t *testing.T) { + tests := []struct { + expr string + out string + }{ + { + expr: `"list: %s".format([[[1, 2, [3.0, 4]]]])`, + out: `list: [[1, 2, [3, 4]]]`, + }, + { + expr: `"list size: %d".format([[[1, 2, [3.0, 4]]].size()])`, + out: `list size: 1`, + }, + { + expr: `"list element: %s".format([[[1, 2, [3.0, 4]]][0]])`, + out: `list element: [1, 2, [3, 4]]`, + }, + } + env, err := cel.NewEnv(Strings(), cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals())) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("Eval() failed: %v", err) + } + if out.Value() != tc.out { + t.Errorf("Eval() got %v, wanted %v", out, tc.out) + } + }) + } +} diff --git a/ext/strings.go b/ext/strings.go index 2e590a4c5..88b4d7f03 100644 --- a/ext/strings.go +++ b/ext/strings.go @@ -286,10 +286,15 @@ const ( // // 'gums'.reverse() // returns 'smug' // 'John Smith'.reverse() // returns 'htimS nhoJ' +// +// Introduced at version: 4 +// +// Formatting updated to adhere to https://github.com/google/cel-spec/blob/master/doc/extensions/strings.md. +// +// .format() -> func Strings(options ...StringsOption) cel.EnvOption { s := &stringLib{ - version: math.MaxUint32, - validateFormat: true, + version: math.MaxUint32, } for _, o := range options { s = o(s) @@ -298,9 +303,8 @@ func Strings(options ...StringsOption) cel.EnvOption { } type stringLib struct { - locale string - version uint32 - validateFormat bool + locale string + version uint32 } // LibraryName implements the SingletonLibrary interface method. 
@@ -314,6 +318,8 @@ type StringsOption func(*stringLib) *stringLib // StringsLocale configures the library with the given locale. The locale tag will // be checked for validity at the time that EnvOptions are configured. If this option // is not passed, string.format will behave as if en_US was passed as the locale. +// +// If StringsVersion is greater than or equal to 4, this option is ignored. func StringsLocale(locale string) StringsOption { return func(sl *stringLib) *stringLib { sl.locale = locale @@ -340,10 +346,9 @@ func StringsVersion(version uint32) StringsOption { // StringsValidateFormatCalls validates type-checked ASTs to ensure that string.format() calls have // valid formatting clauses and valid argument types for each clause. // -// Enabled by default. +// Deprecated func StringsValidateFormatCalls(value bool) StringsOption { return func(s *stringLib) *stringLib { - s.validateFormat = value return s } } @@ -351,7 +356,7 @@ func StringsValidateFormatCalls(value bool) StringsOption { // CompileOptions implements the Library interface method. 
func (lib *stringLib) CompileOptions() []cel.EnvOption { formatLocale := "en_US" - if lib.locale != "" { + if lib.version < 4 && lib.locale != "" { // ensure locale is properly-formed if set _, err := language.Parse(lib.locale) if err != nil { @@ -466,21 +471,29 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption { }))), } if lib.version >= 1 { - opts = append(opts, cel.Function("format", - cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType, - cel.FunctionBinding(func(args ...ref.Val) ref.Val { - s := string(args[0].(types.String)) - formatArgs := args[1].(traits.Lister) - return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale)) - }))), + if lib.version >= 4 { + opts = append(opts, cel.Function("format", + cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType, + cel.FunctionBinding(func(args ...ref.Val) ref.Val { + s := string(args[0].(types.String)) + formatArgs := args[1].(traits.Lister) + return stringOrError(parseFormatStringV2(s, &stringFormatterV2{}, &stringArgList{formatArgs})) + })))) + } else { + opts = append(opts, cel.Function("format", + cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType, + cel.FunctionBinding(func(args ...ref.Val) ref.Val { + s := string(args[0].(types.String)) + formatArgs := args[1].(traits.Lister) + return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale)) + })))) + } + opts = append(opts, cel.Function("strings.quote", cel.Overload("strings_quote", []*cel.Type{cel.StringType}, cel.StringType, cel.UnaryBinding(func(str ref.Val) ref.Val { s := str.(types.String) return stringOrError(quote(string(s))) - }))), - - cel.ASTValidators(stringFormatValidator{})) - + })))) } if lib.version >= 2 { opts = append(opts, @@ -529,8 +542,12 @@ func (lib *stringLib) CompileOptions() 
[]cel.EnvOption { }))), ) } - if lib.validateFormat { - opts = append(opts, cel.ASTValidators(stringFormatValidator{})) + if lib.version >= 1 { + if lib.version >= 4 { + opts = append(opts, cel.ASTValidators(stringFormatValidatorV2{})) + } else { + opts = append(opts, cel.ASTValidators(stringFormatValidator{})) + } } return opts } diff --git a/ext/strings_test.go b/ext/strings_test.go index 22aab84e6..3a8adeb09 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -16,21 +16,15 @@ package ext import ( "fmt" - "math" - "reflect" "strings" "testing" "time" "unicode/utf8" - "google.golang.org/protobuf/proto" - "github.com/google/cel-go/cel" "github.com/google/cel-go/checker" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - - proto3pb "github.com/google/cel-go/test/proto3pb" ) // TODO: move these tests to a conformance test. @@ -444,1109 +438,6 @@ func TestStringsWithExtension(t *testing.T) { } } -func TestStringFormat(t *testing.T) { - tests := []struct { - name string - format string - dynArgs map[string]any - formatArgs string - locale string - err string - expectedOutput string - expectedRuntimeCost uint64 - expectedEstimatedCost checker.CostEstimate - skipCompileCheck bool - }{ - { - name: "no-op", - format: "no substitution", - expectedOutput: "no substitution", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - - { - name: "mid-string substitution", - format: "str is %s and some more", - formatArgs: `"filler"`, - expectedOutput: "str is filler and some more", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "percent escaping", - format: "%% and also %%", - expectedOutput: "% and also %", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - { - name: "substution inside escaped percent signs", - format: "%%%s%%", - formatArgs: `"text"`, - expectedOutput: "%text%", - 
expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "substitution with one escaped percent sign on the right", - format: "%s%%", - formatArgs: `"percent on the right"`, - expectedOutput: "percent on the right%", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "substitution with one escaped percent sign on the left", - format: "%%%s", - formatArgs: `"percent on the left"`, - expectedOutput: "%percent on the left", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "multiple substitutions", - format: "%d %d %d, %s %s %s, %d %d %d, %s %s %s", - formatArgs: `1, 2, 3, "A", "B", "C", 4, 5, 6, "D", "E", "F"`, - expectedOutput: "1 2 3, A B C, 4 5 6, D E F", - expectedRuntimeCost: 14, - expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, - }, - { - name: "percent sign escape sequence support", - format: "\u0025\u0025escaped \u0025s\u0025\u0025", - formatArgs: `"percent"`, - expectedOutput: "%escaped percent%", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - { - name: "fixed point formatting clause", - format: "%.3f", - formatArgs: "1.2345", - expectedOutput: "1.234", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "binary formatting clause", - format: "this is 5 in binary: %b", - formatArgs: "5", - expectedOutput: "this is 5 in binary: 101", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "negative binary formatting clause", - format: "this is -5 in binary: %b", - formatArgs: "-5", - expectedOutput: "this is -5 in binary: -101", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "uint support for binary formatting", - format: "unsigned 64 in 
binary: %b", - formatArgs: "uint(64)", - expectedOutput: "unsigned 64 in binary: 1000000", - expectedRuntimeCost: 14, - expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, - }, - { - name: "bool support for binary formatting", - format: "bit set from bool: %b", - formatArgs: "true", - expectedOutput: "bit set from bool: 1", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "octal formatting clause", - format: "%o", - formatArgs: "11", - expectedOutput: "13", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "negative octal formatting clause", - format: "%o", - formatArgs: "-11", - expectedOutput: "-13", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "uint support for octal formatting clause", - format: "this is an unsigned octal: %o", - formatArgs: "uint(65535)", - expectedOutput: "this is an unsigned octal: 177777", - expectedRuntimeCost: 14, - expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, - }, - { - name: "lowercase hexadecimal formatting clause", - format: "%x is 30 in hexadecimal", - formatArgs: "30", - expectedOutput: "1e is 30 in hexadecimal", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "uppercase hexadecimal formatting clause", - format: "%X is 20 in hexadecimal", - formatArgs: "30", - expectedOutput: "1E is 20 in hexadecimal", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "negative hexadecimal formatting clause", - format: "%x is -30 in hexadecimal", - formatArgs: "-30", - expectedOutput: "-1e is -30 in hexadecimal", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "unsigned support for hexadecimal formatting clause", - format: "%X is 6000 in hexadecimal", - 
formatArgs: "uint(6000)", - expectedOutput: "1770 is 6000 in hexadecimal", - expectedRuntimeCost: 14, - expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, - }, - { - name: "string support with hexadecimal formatting clause", - format: "%x", - formatArgs: `"Hello world!"`, - expectedOutput: "48656c6c6f20776f726c6421", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "string support with uppercase hexadecimal formatting clause", - format: "%X", - formatArgs: `"Hello world!"`, - expectedOutput: "48656C6C6F20776F726C6421", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "byte support with hexadecimal formatting clause", - format: "%x", - formatArgs: `b"byte string"`, - expectedOutput: "6279746520737472696e67", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "byte support with hexadecimal formatting clause leading zero", - format: "%x", - formatArgs: `b"\x00\x00byte string\x00"`, - expectedOutput: "00006279746520737472696e6700", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "byte support with uppercase hexadecimal formatting clause", - format: "%X", - formatArgs: `b"byte string"`, - expectedOutput: "6279746520737472696E67", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "scientific notation formatting clause", - format: "%.6e", - formatArgs: "1052.032911275", - expectedOutput: "1.052033\u202f\u00d7\u202f10\u2070\u00b3", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "locale support", - format: "%.3f", - formatArgs: "3.14", - locale: "fr_FR", - expectedOutput: "3,140", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: 
"default precision for fixed-point clause", - format: "%f", - formatArgs: "2.71828", - expectedOutput: "2.718280", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "default precision for scientific notation", - format: "%e", - formatArgs: "2.71828", - expectedOutput: "2.718280\u202f\u00d7\u202f10\u2070\u2070", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "default precision for string", - format: "%s", - formatArgs: "2.71", - expectedOutput: "2.71", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "default list precision for string", - format: "%s", - formatArgs: "[2.71]", - expectedOutput: "[2.710000]", - expectedRuntimeCost: 21, - expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, - locale: "en_US", - }, - { - name: "default scientific notation for string", - format: "%s", - formatArgs: "0.000000002", - expectedOutput: "2e-09", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "default list scientific notation for string", - format: "%s", - formatArgs: "[0.000000002]", - expectedOutput: "[0.000000]", - expectedRuntimeCost: 21, - expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, - locale: "en_US", - }, - { - name: "unicode output for scientific notation", - format: "unescaped unicode: %e, escaped unicode: %e", - formatArgs: "2.71828, 2.71828", - expectedOutput: "unescaped unicode: 2.718280 × 10⁰⁰, escaped unicode: 2.718280\u202f\u00d7\u202f10\u2070\u2070", - expectedRuntimeCost: 15, - expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, - locale: "en_US", - }, - { - name: "NaN support for fixed-point", - format: "%f", - formatArgs: `"NaN"`, - expectedOutput: "NaN", - expectedRuntimeCost: 11, - expectedEstimatedCost: 
checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "positive infinity support for fixed-point", - format: "%f", - formatArgs: `"Infinity"`, - expectedOutput: "∞", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "negative infinity support for fixed-point", - format: "%f", - formatArgs: `"-Infinity"`, - expectedOutput: "-∞", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - locale: "en_US", - }, - { - name: "NaN support for string", - format: "%s", - formatArgs: `double("NaN")`, - expectedOutput: "NaN", - }, - { - name: "positive infinity support for string", - format: "%s", - formatArgs: `double("Inf")`, - expectedOutput: "+Inf", - }, - { - name: "negative infinity support for string", - format: "%s", - formatArgs: `double("-Inf")`, - expectedOutput: "-Inf", - }, - { - name: "infinity list support for string", - format: "%s", - formatArgs: `[double("NaN"),double("+Inf"), double("-Inf")]`, - expectedOutput: `["NaN", "+Inf", "-Inf"]`, - }, - { - name: "uint support for decimal clause", - format: "%d", - formatArgs: "uint(64)", - expectedOutput: "64", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - { - name: "null support for string", - format: "null: %s", - formatArgs: "null", - expectedOutput: "null: null", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "int support for string", - format: "%s", - formatArgs: `999999999999`, - expectedOutput: "999999999999", - expectedRuntimeCost: 11, - expectedEstimatedCost: checker.CostEstimate{Min: 11, Max: 11}, - }, - { - name: "bytes support for string", - format: "some bytes: %s", - formatArgs: `b"xyz"`, - expectedOutput: "some bytes: xyz", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - { - name: "type() support for 
string", - format: "type is %s", - formatArgs: `type("test string")`, - expectedOutput: "type is string", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - { - name: "timestamp support for string", - format: "%s", - formatArgs: `timestamp("2023-02-03T23:31:20+00:00")`, - expectedOutput: "2023-02-03T23:31:20Z", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - { - name: "duration support for string", - format: "%s", - formatArgs: `duration("1h45m47s")`, - expectedOutput: "6347s", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - { - name: "small duration support for string", - format: "%s", - formatArgs: `duration("2ns")`, - expectedOutput: "0.000000002s", - expectedRuntimeCost: 12, - expectedEstimatedCost: checker.CostEstimate{Min: 12, Max: 12}, - }, - { - name: "list support for string", - format: "%s", - formatArgs: `["abc", 3.14, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")]`, - expectedOutput: `["abc", 3.140000, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")]`, - expectedRuntimeCost: 32, - expectedEstimatedCost: checker.CostEstimate{Min: 32, Max: 32}, - }, - { - name: "map support for string", - format: "%s", - formatArgs: `{"key1": b"xyz", "key5": null, "key2": duration("2h"), "key4": true, "key3": 2.71828}`, - locale: "nl_NL", - expectedOutput: `{"key1":b"xyz", "key2":duration("7200s"), "key3":2.718280, "key4":true, "key5":null}`, - expectedRuntimeCost: 42, - expectedEstimatedCost: checker.CostEstimate{Min: 42, Max: 42}, - }, - { - name: "map support (all key types)", - format: "map with multiple key types: %s", - formatArgs: `{1: "value1", uint(2): "value2", true: double("NaN")}`, - expectedOutput: `map with multiple key types: {1:"value1", 2:"value2", true:"NaN"}`, - expectedRuntimeCost: 46, - expectedEstimatedCost: checker.CostEstimate{Min: 46, Max: 46}, - }, - { - name: "boolean support for %s", 
- format: "true bool: %s, false bool: %s", - formatArgs: `true, false`, - expectedOutput: "true bool: true, false bool: false", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "dyntype support for string formatting clause", - format: "dynamic string: %s", - formatArgs: `dynStr`, - dynArgs: map[string]any{ - "dynStr": "a string", - }, - expectedOutput: "dynamic string: a string", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "dyntype support for numbers with string formatting clause", - format: "dynIntStr: %s dynDoubleStr: %s", - formatArgs: `dynIntStr, dynDoubleStr`, - dynArgs: map[string]any{ - "dynIntStr": 32, - "dynDoubleStr": 56.8, - }, - expectedOutput: "dynIntStr: 32 dynDoubleStr: 56.8", - expectedRuntimeCost: 15, - expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, - locale: "en_US", - }, - { - name: "dyntype support for integer formatting clause", - format: "dynamic int: %d", - formatArgs: `dynInt`, - dynArgs: map[string]any{ - "dynInt": 128, - }, - expectedOutput: "dynamic int: 128", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "dyntype support for integer formatting clause (unsigned)", - format: "dynamic unsigned int: %d", - formatArgs: `dynUnsignedInt`, - dynArgs: map[string]any{ - "dynUnsignedInt": uint64(256), - }, - expectedOutput: "dynamic unsigned int: 256", - expectedRuntimeCost: 14, - expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, - }, - { - name: "dyntype support for hex formatting clause", - format: "dynamic hex int: %x", - formatArgs: `dynHexInt`, - dynArgs: map[string]any{ - "dynHexInt": 22, - }, - expectedOutput: "dynamic hex int: 16", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "dyntype support for hex formatting clause (uppercase)", - format: "dynamic hex int: %X 
(uppercase)", - formatArgs: `dynHexInt`, - dynArgs: map[string]any{ - "dynHexInt": 26, - }, - expectedOutput: "dynamic hex int: 1A (uppercase)", - expectedRuntimeCost: 15, - expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, - }, - { - name: "dyntype support for unsigned hex formatting clause", - format: "dynamic hex int: %x (unsigned)", - formatArgs: `dynUnsignedHexInt`, - dynArgs: map[string]any{ - "dynUnsignedHexInt": uint(500), - }, - expectedOutput: "dynamic hex int: 1f4 (unsigned)", - expectedRuntimeCost: 14, - expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, - }, - { - name: "dyntype support for fixed-point formatting clause", - format: "dynamic double: %.3f", - formatArgs: `dynDouble`, - dynArgs: map[string]any{ - "dynDouble": 4.5, - }, - expectedOutput: "dynamic double: 4.500", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - locale: "en_US", - }, - { - name: "dyntype support for fixed-point formatting clause (comma separator locale)", - format: "dynamic double: %f", - formatArgs: `dynDouble`, - dynArgs: map[string]any{ - "dynDouble": 4.5, - }, - expectedOutput: "dynamic double: 4,500000", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - locale: "fr_FR", - }, - { - name: "dyntype support for scientific notation", - format: "(dyntype) e: %e", - formatArgs: "dynE", - dynArgs: map[string]any{ - "dynE": 2.71828, - }, - expectedOutput: "(dyntype) e: 2.718280\u202f\u00d7\u202f10\u2070\u2070", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - locale: "en_US", - }, - { - name: "dyntype NaN/infinity support for fixed-point", - format: "NaN: %f, infinity: %f", - formatArgs: `dynNaN, dynInf`, - dynArgs: map[string]any{ - "dynNaN": math.NaN(), - "dynInf": math.Inf(1), - }, - expectedOutput: "NaN: NaN, infinity: ∞", - expectedRuntimeCost: 15, - expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, - }, 
- { - name: "dyntype support for timestamp", - format: "dyntype timestamp: %s", - formatArgs: `dynTime`, - dynArgs: map[string]any{ - "dynTime": time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), - }, - expectedOutput: "dyntype timestamp: 2009-11-10T23:00:00Z", - expectedRuntimeCost: 14, - expectedEstimatedCost: checker.CostEstimate{Min: 14, Max: 14}, - }, - { - name: "dyntype support for duration", - format: "dyntype duration: %s", - formatArgs: `dynDuration`, - dynArgs: map[string]any{ - "dynDuration": mustParseDuration("2h25m47s"), - }, - expectedOutput: "dyntype duration: 8747s", - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "dyntype support for lists", - format: "dyntype list: %s", - formatArgs: `dynList`, - dynArgs: map[string]any{ - "dynList": []any{6, 4.2, "a string"}, - }, - expectedOutput: `dyntype list: [6, 4.200000, "a string"]`, - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "dyntype support for maps", - format: "dyntype map: %s", - formatArgs: `dynMap`, - dynArgs: map[string]any{ - "dynMap": map[any]any{ - "strKey": "x", - true: 42, - int64(6): mustParseDuration("7m2s"), - }, - }, - expectedOutput: `dyntype map: {"strKey":"x", 6:duration("422s"), true:42}`, - expectedRuntimeCost: 13, - expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, - }, - { - name: "message field support", - format: "message field msg.single_int32: %d, msg.single_double: %.1f", - formatArgs: `msg.single_int32, msg.single_double`, - dynArgs: map[string]any{ - "msg": &proto3pb.TestAllTypes{ - SingleInt32: 2, - SingleDouble: 1.0, - }, - }, - locale: "en_US", - expectedOutput: `message field msg.single_int32: 2, msg.single_double: 1.0`, - }, - { - name: "unrecognized formatting clause", - format: "%a", - formatArgs: "1", - skipCompileCheck: true, - err: "could not parse formatting clause: unrecognized formatting clause \"a\"", - }, - { - 
name: "out of bounds arg index", - format: "%d %d %d", - formatArgs: "0, 1", - skipCompileCheck: true, - err: "index 2 out of range", - }, - { - name: "string substitution is not allowed with binary clause", - format: "string is %b", - formatArgs: `"abc"`, - skipCompileCheck: true, - err: "error during formatting: only integers and bools can be formatted as binary, was given string", - }, - { - name: "duration substitution not allowed with decimal clause", - format: "%d", - formatArgs: `duration("30m2s")`, - skipCompileCheck: true, - err: "error during formatting: decimal clause can only be used on integers, was given google.protobuf.Duration", - }, - { - name: "string substitution not allowed with octal clause", - format: "octal: %o", - formatArgs: `"a string"`, - skipCompileCheck: true, - err: "error during formatting: octal clause can only be used on integers, was given string", - }, - { - name: "double substitution not allowed with hex clause", - format: "double is %x", - formatArgs: "0.5", - skipCompileCheck: true, - err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex, was given double", - }, - { - name: "uppercase not allowed for scientific clause", - format: "double is %E", - formatArgs: "0.5", - skipCompileCheck: true, - err: `could not parse formatting clause: unrecognized formatting clause "E"`, - }, - { - name: "object not allowed", - format: "object is %s", - formatArgs: `ext.TestAllTypes{PbVal: test.TestAllTypes{}}`, - skipCompileCheck: true, - err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", - }, - { - name: "object inside list", - format: "%s", - formatArgs: "[1, 2, ext.TestAllTypes{PbVal: test.TestAllTypes{}}]", - skipCompileCheck: true, - err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and 
timestamps, was given ext.TestAllTypes", - }, - { - name: "object inside map", - format: "%s", - formatArgs: `{1: "a", 2: ext.TestAllTypes{}}`, - skipCompileCheck: true, - err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", - }, - { - name: "null not allowed for %d", - format: "null: %d", - formatArgs: "null", - skipCompileCheck: true, - err: "error during formatting: decimal clause can only be used on integers, was given null_type", - }, - { - name: "null not allowed for %e", - format: "null: %e", - formatArgs: "null", - skipCompileCheck: true, - err: "error during formatting: scientific clause can only be used on doubles, was given null_type", - }, - { - name: "null not allowed for %f", - format: "null: %f", - formatArgs: "null", - skipCompileCheck: true, - err: "error during formatting: fixed-point clause can only be used on doubles, was given null_type", - }, - { - name: "null not allowed for %x", - format: "null: %x", - formatArgs: "null", - skipCompileCheck: true, - err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex, was given null_type", - }, - { - name: "null not allowed for %X", - format: "null: %X", - formatArgs: "null", - skipCompileCheck: true, - err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex, was given null_type", - }, - { - name: "null not allowed for %b", - format: "null: %b", - formatArgs: "null", - skipCompileCheck: true, - err: "error during formatting: only integers and bools can be formatted as binary, was given null_type", - }, - { - name: "null not allowed for %o", - format: "null: %o", - formatArgs: "null", - skipCompileCheck: true, - err: "error during formatting: octal clause can only be used on integers, was given null_type", - }, - { - name: "compile-time cardinality check (too few for string)", - format: "%s %s", - 
formatArgs: `"abc"`, - err: "index 1 out of range", - }, - { - name: "compile-time cardinality check (too many for string)", - format: "%s %s", - formatArgs: `"abc", "def", "ghi"`, - err: "too many arguments supplied to string.format (expected 2, got 3)", - }, - { - name: "compile-time syntax check (unexpected end of string)", - format: "filler %", - formatArgs: "", - err: "unexpected end of string", - }, - { - name: "compile-time syntax check (unrecognized formatting clause)", - format: "%j", - // pass args here, otherwise the cardinality check will fail first - formatArgs: "123", - err: `could not parse formatting clause: unrecognized formatting clause "j"`, - }, - { - name: "compile-time %s check", - format: "object is %s", - formatArgs: `ext.TestAllTypes{PbVal: test.TestAllTypes{}}`, - err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps", - }, - { - name: "compile-time check for objects inside list literal", - format: "list is %s", - formatArgs: `[1, 2, ext.TestAllTypes{PbVal: test.TestAllTypes{}}]`, - err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps", - }, - { - name: "compile-time %d check", - format: "int is %d", - formatArgs: "5.2", - err: "error during formatting: decimal clause can only be used on integers", - }, - { - name: "compile-time %f check", - format: "double is %f", - formatArgs: "true", - err: "error during formatting: fixed-point clause can only be used on doubles", - }, - { - name: "compile-time precision syntax check", - format: "double is %.34", - formatArgs: "5.0", - err: "could not parse formatting clause: error while parsing precision: could not find end of precision specifier", - }, - { - name: "compile-time %e check", - format: "double is %e", - formatArgs: "true", - err: "error during formatting: scientific clause can only be used on 
doubles", - }, - { - name: "compile-time %b check", - format: "string is %b", - formatArgs: `"a string"`, - err: "error during formatting: only integers and bools can be formatted as binary", - }, - { - name: "compile-time %x check", - format: "%x is a double", - formatArgs: "2.5", - err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex", - }, - { - name: "compile-time %X check", - format: "%X is a double", - formatArgs: "2.5", - err: "error during formatting: only integers, byte buffers, and strings can be formatted as hex", - }, - { - name: "compile-time %o check", - format: "an octal: %o", - formatArgs: "3.14", - err: "error during formatting: octal clause can only be used on integers", - }, - } - evalExpr := func(env *cel.Env, expr string, evalArgs any, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { - t.Logf("evaluating expr: %s", expr) - parsedAst, issues := env.Parse(expr) - if issues.Err() != nil { - t.Fatalf("env.Parse(%v) failed: %v", expr, issues.Err()) - } - checkedAst, issues := env.Check(parsedAst) - if issues.Err() != nil { - return nil, issues.Err() - } - evalOpts := make([]cel.ProgramOption, 0) - costTracker := &noopCostEstimator{} - if expectedRuntimeCost != 0 { - evalOpts = append(evalOpts, cel.CostTracking(costTracker)) - } - program, err := env.Program(checkedAst, evalOpts...) 
- if err != nil { - return nil, err - } - - actualEstimatedCost, err := env.EstimateCost(checkedAst, costTracker) - if err != nil { - t.Fatal(err) - } - if expectedEstimatedCost.Min != 0 && expectedEstimatedCost.Max != 0 { - if actualEstimatedCost.Min != expectedEstimatedCost.Min && actualEstimatedCost.Max != expectedEstimatedCost.Max { - t.Fatalf("expected estimated cost range to be %v, was %v", expectedEstimatedCost, actualEstimatedCost) - } - } - - var out ref.Val - var details *cel.EvalDetails - if evalArgs != nil { - out, details, err = program.Eval(evalArgs) - } else { - out, details, err = program.Eval(cel.NoVars()) - } - - if expectedRuntimeCost != 0 { - if details == nil { - t.Fatal("no EvalDetails available when runtime cost was expected") - } - if *details.ActualCost() != expectedRuntimeCost { - t.Fatalf("expected runtime cost to be %d, was %d", expectedRuntimeCost, *details.ActualCost()) - } - if expectedEstimatedCost.Min != 0 && expectedEstimatedCost.Max != 0 { - if *details.ActualCost() < expectedEstimatedCost.Min || *details.ActualCost() > expectedEstimatedCost.Max { - t.Fatalf("runtime cost %d outside of expected estimated cost range %v", *details.ActualCost(), expectedEstimatedCost) - } - } - } - return out, err - } - buildVariables := func(vars map[string]any) []cel.EnvOption { - opts := make([]cel.EnvOption, len(vars)) - i := 0 - for name, value := range vars { - t := cel.DynType - switch v := value.(type) { - case proto.Message: - t = cel.ObjectType(string(v.ProtoReflect().Descriptor().FullName())) - case types.Bool: - t = cel.BoolType - case types.Bytes: - t = cel.BytesType - case types.Double: - t = cel.DoubleType - case types.Duration: - t = cel.DurationType - case types.Int: - t = cel.IntType - case types.Null: - t = cel.NullType - case types.String: - t = cel.StringType - case types.Timestamp: - t = cel.TimestampType - case types.Uint: - t = cel.UintType - } - opts[i] = cel.Variable(name, t) - i++ - } - return opts - } - buildOpts := 
func(skipCompileCheck bool, locale string, variables []cel.EnvOption) []cel.EnvOption { - opts := []cel.EnvOption{ - Strings(StringsLocale(locale), StringsValidateFormatCalls(!skipCompileCheck)), - cel.Container("ext"), - cel.Abbrevs("google.expr.proto3.test"), - cel.Types(&proto3pb.TestAllTypes{}), - NativeTypes( - reflect.TypeOf(&TestNestedType{}), - reflect.ValueOf(&TestAllTypes{}), - ), - } - opts = append(opts, cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals())) - opts = append(opts, variables...) - return opts - } - runCase := func(format, formatArgs, locale string, dynArgs map[string]any, skipCompileCheck bool, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { - env, err := cel.NewEnv(buildOpts(skipCompileCheck, locale, buildVariables(dynArgs))...) - if err != nil { - t.Fatalf("cel.NewEnv() failed: %v", err) - } - expr := fmt.Sprintf("%q.format([%s])", format, formatArgs) - if len(dynArgs) == 0 { - return evalExpr(env, expr, cel.NoVars(), expectedRuntimeCost, expectedEstimatedCost, t) - } - return evalExpr(env, expr, dynArgs, expectedRuntimeCost, expectedEstimatedCost, t) - } - checkCase := func(output ref.Val, expectedOutput string, err error, expectedErr string, t *testing.T) { - if err != nil { - if expectedErr != "" { - if !strings.Contains(err.Error(), expectedErr) { - t.Fatalf("expected %q as error message, got %q", expectedErr, err.Error()) - } - } else { - t.Fatalf("unexpected error: %s", err) - } - } else { - if output.Type() != types.StringType { - t.Fatalf("expected test expr to eval to string (got %s instead)", output.Type().TypeName()) - } else { - outputStr := output.Value().(string) - if outputStr != expectedOutput { - t.Errorf("expected %q as output, got %q", expectedOutput, outputStr) - } - } - } - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - out, err := runCase(tt.format, tt.formatArgs, tt.locale, tt.dynArgs, tt.skipCompileCheck, 
tt.expectedRuntimeCost, tt.expectedEstimatedCost, t) - checkCase(out, tt.expectedOutput, err, tt.err, t) - if tt.locale == "" { - // if the test has no locale specified, then that means it - // should have the same output regardless of locale - t.Run("no change on locale", func(t *testing.T) { - out, err := runCase(tt.format, tt.formatArgs, "da_DK", tt.dynArgs, tt.skipCompileCheck, tt.expectedRuntimeCost, tt.expectedEstimatedCost, t) - checkCase(out, tt.expectedOutput, err, tt.err, t) - }) - } - }) - } -} - -func TestStringFormatHeterogeneousLiterals(t *testing.T) { - tests := []struct { - expr string - out string - }{ - { - expr: `"list: %s".format([[[1, 2, [3.0, 4]]]])`, - out: `list: [[1, 2, [3.000000, 4]]]`, - }, - { - expr: `"list size: %d".format([[[1, 2, [3.0, 4]]].size()])`, - out: `list size: 1`, - }, - { - expr: `"list element: %s".format([[[1, 2, [3.0, 4]]][0]])`, - out: `list element: [1, 2, [3.000000, 4]]`, - }, - } - env, err := cel.NewEnv(Strings(), cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals())) - if err != nil { - t.Fatalf("cel.NewEnv() failed: %v", err) - } - for _, tst := range tests { - tc := tst - t.Run(tc.expr, func(t *testing.T) { - ast, iss := env.Compile(tc.expr) - if iss.Err() != nil { - t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err()) - } - prg, err := env.Program(ast) - if err != nil { - t.Fatalf("env.Program() failed: %v", err) - } - out, _, err := prg.Eval(cel.NoVars()) - if err != nil { - t.Fatalf("Eval() failed: %v", err) - } - if out.Value() != tc.out { - t.Errorf("Eval() got %v, wanted %v", out, tc.out) - } - }) - } -} - -func TestBadLocale(t *testing.T) { - _, err := cel.NewEnv(Strings(StringsLocale("bad-locale"))) - if err != nil { - if err.Error() != "failed to parse locale: language: subtag \"locale\" is well-formed but unknown" { - t.Errorf("expected error messaged to be \"failed to parse locale: language: subtag \"locale\" is well-formed but unknown\", got %q", err) - } - } else { - t.Error("expected 
NewEnv to fail during locale parsing") - } -} - -func TestLiteralOutput(t *testing.T) { - tests := []struct { - name string - formatLiteral string - expectedType string - }{ - { - name: "map literal support", - formatLiteral: `{"key1": b"xyz", false: [11, 12, 13, timestamp("2019-10-12T07:20:50.52Z")], 42: {uint(64): 2.7}, "key5": type(int), "key2": duration("2h"), "key4": true, "key3": 2.71828, "null": null}`, - expectedType: `map`, - }, - { - name: "list literal support", - formatLiteral: `["abc", 3.14, uint(32), b"def", null, type(string), duration("7m"), [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")]`, - expectedType: `list`, - }, - } - for _, tt := range tests { - parseAndEval := func(expr string, t *testing.T) (ref.Val, error) { - env, err := cel.NewEnv(Strings()) - if err != nil { - t.Fatalf("cel.NewEnv(Strings()) failed: %v", err) - } - parsedAst, issues := env.Parse(expr) - if issues.Err() != nil { - t.Fatalf("env.Parse(%v) failed: %v", expr, issues.Err()) - } - checkedAst, issues := env.Check(parsedAst) - if issues.Err() != nil { - t.Fatalf("env.Check(%v) failed: %v", expr, issues.Err()) - } - program, err := env.Program(checkedAst) - if err != nil { - t.Fatal(err) - } - out, _, err := program.Eval(cel.NoVars()) - return out, err - } - t.Run(tt.name, func(t *testing.T) { - expr := fmt.Sprintf(`"%%s".format([%s])`, tt.formatLiteral) - literalVal, err := parseAndEval(expr, t) - if err != nil { - t.Fatalf("program.Eval failed: %v", err) - } - out, err := parseAndEval(literalVal.Value().(string), t) - if err != nil { - t.Fatalf("literal evaluation failed: %v", err) - } - if out.Type().TypeName() != tt.expectedType { - t.Errorf("expected literal to evaluate to type %s, got %s", tt.expectedType, out.Type().TypeName()) - } - equivalentVal, err := parseAndEval(literalVal.Value().(string)+" == "+tt.formatLiteral, t) - if err != nil { - t.Fatalf("equality evaluation failed: %v:", err) - } - if equivalentVal.Type().TypeName() != "bool" { - t.Errorf("expected 
equality expression to evaluation to type bool, got %s", equivalentVal.Type().TypeName()) - } - equivalent := equivalentVal.Value().(bool) - if !equivalent { - t.Errorf("%q (observed) and %q (expected) not considered equivalent", literalVal.Value().(string), tt.formatLiteral) - } - }) - } -} - func mustParseDuration(s string) time.Duration { d, err := time.ParseDuration(s) if err != nil { diff --git a/go.mod b/go.mod index ae23e9ee6..8bf321c5a 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,9 @@ go 1.21.1 toolchain go1.23.0 require ( - cel.dev/expr v0.19.1 + cel.dev/expr v0.21.2 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/stoewer/go-strcase v1.2.0 - golang.org/x/text v0.16.0 google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v3 v3.0.1 @@ -16,5 +15,6 @@ require ( require ( golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect + golang.org/x/text v0.22.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 // indirect ) diff --git a/go.sum b/go.sum index b518e1a53..69a046f2b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= -cel.dev/expr v0.19.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.21.2 h1:o+Wj235dy4gFYlYin3JsMpp3EEfMrPm/6tdoyjT98S0= +cel.dev/expr v0.21.2/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= @@ -15,8 +15,8 @@ github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= 
golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 h1:YcyjlL1PRr2Q17/I0dPk2JmYS5CDXfcdb2Z3YRioEbw= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 h1:2035KHhUv+EpyB+hWgJnaWKJOdX1E95w2S8Rr4uWKTs= diff --git a/repl/evaluator.go b/repl/evaluator.go index c26b09c62..6befc46a2 100644 --- a/repl/evaluator.go +++ b/repl/evaluator.go @@ -1002,13 +1002,7 @@ func (e *Evaluator) Process(cmd Cmder) (string, bool, error) { if ok { return fmt.Sprintf("Unknown %v", unknown), false, nil } - v, err := ext.FormatString(val, "") - - if err != nil { - // Default format if type is unsupported by ext.Strings formatter. 
- return fmt.Sprintf("%v : %s", val.Value(), t), false, nil - } - return fmt.Sprintf("%s : %s", v, t), false, nil + return fmt.Sprintf("%s : %s", types.Format(val), t), false, nil } case *letVarCmd: var err error diff --git a/repl/evaluator_test.go b/repl/evaluator_test.go index dbe117724..5a82c2fa5 100644 --- a/repl/evaluator_test.go +++ b/repl/evaluator_test.go @@ -535,7 +535,7 @@ func TestProcess(t *testing.T) { expr: "1u + 2u", }, }, - wantText: "3 : uint", + wantText: "3u : uint", wantExit: false, wantError: false, }, @@ -546,7 +546,7 @@ func TestProcess(t *testing.T) { expr: `'a' + r'b\1'`, }, }, - wantText: `ab\1 : string`, + wantText: `"ab\\1" : string`, wantExit: false, wantError: false, }, @@ -557,7 +557,7 @@ func TestProcess(t *testing.T) { expr: `['abc', 123, 3.14, duration('2m')]`, }, }, - wantText: `["abc", 123, 3.140000, duration("120s")] : list(dyn)`, + wantText: `["abc", 123, 3.14, duration("120s")] : list(dyn)`, wantExit: false, wantError: false, }, @@ -568,7 +568,7 @@ func TestProcess(t *testing.T) { expr: `{1: 123, 2: 3.14, 3: duration('2m'), 4: b'123'}`, }, }, - wantText: `{1:123, 2:3.140000, 3:duration("120s"), 4:b"123"} : map(int, dyn)`, + wantText: `{1: 123, 2: 3.14, 3: duration("120s"), 4: b"\061\062\063"} : map(int, dyn)`, wantExit: false, wantError: false, }, @@ -619,7 +619,7 @@ func TestProcess(t *testing.T) { expr: "optional.none().orValue('default')", }, }, - wantText: "default : string", + wantText: "\"default\" : string", wantExit: false, wantError: false, }, @@ -637,7 +637,7 @@ func TestProcess(t *testing.T) { expr: "'test'.substring(2)", }, }, - wantText: "st : string", + wantText: "\"st\" : string", wantExit: false, wantError: false, }, @@ -691,7 +691,7 @@ func TestProcess(t *testing.T) { expr: "base64.encode(b'hello')", }, }, - wantText: "aGVsbG8= : string", + wantText: "\"aGVsbG8=\" : string", wantExit: false, wantError: false, }, @@ -825,7 +825,7 @@ func TestProcess(t *testing.T) { expr: "AttributeContext.Request{host: 
'www.example.com'}", }, }, - wantText: `host:"www.example.com" : google.rpc.context.AttributeContext.Request`, + wantText: `google.rpc.context.AttributeContext.Request{host: "www.example.com"} : google.rpc.context.AttributeContext.Request`, wantExit: false, wantError: false, }, diff --git a/repl/go.mod b/repl/go.mod index 01c9db4ec..12e764df3 100644 --- a/repl/go.mod +++ b/repl/go.mod @@ -3,7 +3,7 @@ module github.com/google/cel-go/repl go 1.21.1 require ( - cel.dev/expr v0.18.0 + cel.dev/expr v0.21.2 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/chzyer/readline v1.5.1 github.com/google/cel-go v0.0.0-00010101000000-000000000000 @@ -16,7 +16,7 @@ require ( github.com/stoewer/go-strcase v1.3.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect golang.org/x/sys v0.21.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.22.0 // indirect ) replace github.com/google/cel-go => ../. diff --git a/repl/go.sum b/repl/go.sum index e9660f26b..96c2f193b 100644 --- a/repl/go.sum +++ b/repl/go.sum @@ -27,8 +27,8 @@ golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnL golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 h1:YcyjlL1PRr2Q17/I0dPk2JmYS5CDXfcdb2Z3YRioEbw= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= 
google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 h1:2035KHhUv+EpyB+hWgJnaWKJOdX1E95w2S8Rr4uWKTs= diff --git a/vendor/cel.dev/expr/README.md b/vendor/cel.dev/expr/README.md index 7930c0b75..42d67f87c 100644 --- a/vendor/cel.dev/expr/README.md +++ b/vendor/cel.dev/expr/README.md @@ -69,5 +69,3 @@ For more detail, see: * [Language Definition](doc/langdef.md) Released under the [Apache License](LICENSE). - -Disclaimer: This is not an official Google product. diff --git a/vendor/golang.org/x/text/LICENSE b/vendor/golang.org/x/text/LICENSE index 6a66aea5e..2a7cf70da 100644 --- a/vendor/golang.org/x/text/LICENSE +++ b/vendor/golang.org/x/text/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. +Copyright 2009 The Go Authors. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer. copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - * Neither the name of Google Inc. nor the names of its + * Neither the name of Google LLC nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. diff --git a/vendor/golang.org/x/text/internal/catmsg/codec.go b/vendor/golang.org/x/text/internal/catmsg/codec.go index 49c9fc978..547802b0f 100644 --- a/vendor/golang.org/x/text/internal/catmsg/codec.go +++ b/vendor/golang.org/x/text/internal/catmsg/codec.go @@ -257,7 +257,7 @@ func (d *Decoder) setError(err error) { // Language returns the language in which the message is being rendered. // // The destination language may be a child language of the language used for -// encoding. For instance, a decoding language of "pt-PT"" is consistent with an +// encoding. 
For instance, a decoding language of "pt-PT" is consistent with an // encoding language of "pt". func (d *Decoder) Language() language.Tag { return d.tag } diff --git a/vendor/modules.txt b/vendor/modules.txt index b7fb68df9..620a2e524 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# cel.dev/expr v0.19.1 +# cel.dev/expr v0.21.2 ## explicit; go 1.21.1 cel.dev/expr # github.com/antlr4-go/antlr/v4 v4.13.0 @@ -11,7 +11,7 @@ github.com/stoewer/go-strcase ## explicit; go 1.20 golang.org/x/exp/constraints golang.org/x/exp/slices -# golang.org/x/text v0.16.0 +# golang.org/x/text v0.22.0 ## explicit; go 1.18 golang.org/x/text/feature/plural golang.org/x/text/internal From dc6468be0c7c8f3c327193ee90a538db0cc292ab Mon Sep 17 00:00:00 2001 From: "zev.ac" Date: Wed, 26 Mar 2025 04:02:40 +0530 Subject: [PATCH 19/46] Add a new compiler tool which can be used to compile CEL expressions and policies using serialized environment (#1143) Support for bundling CEL environments into a library suitable for command line tools. 
--- WORKSPACE | 4 +- common/env/env.go | 2 +- ext/extension_option_factory_test.go | 3 + go.mod | 2 +- go.sum | 4 +- policy/BUILD.bazel | 10 +- policy/testdata/k8s/config.yaml | 4 + tools/compiler/BUILD.bazel | 75 +++ tools/compiler/compiler.go | 534 ++++++++++++++++++ tools/compiler/compiler_test.go | 492 ++++++++++++++++ tools/compiler/testdata/config.yaml | 84 +++ .../compiler/testdata/custom_policy.celpolicy | 25 + .../testdata/custom_policy_config.yaml | 18 + tools/go.mod | 22 + tools/go.sum | 37 ++ vendor/cel.dev/expr/.bazelversion | 2 +- vendor/cel.dev/expr/MODULE.bazel | 26 +- vendor/cel.dev/expr/cloudbuild.yaml | 2 +- vendor/modules.txt | 2 +- 19 files changed, 1327 insertions(+), 21 deletions(-) create mode 100644 tools/compiler/BUILD.bazel create mode 100644 tools/compiler/compiler.go create mode 100644 tools/compiler/compiler_test.go create mode 100644 tools/compiler/testdata/config.yaml create mode 100644 tools/compiler/testdata/custom_policy.celpolicy create mode 100644 tools/compiler/testdata/custom_policy_config.yaml create mode 100644 tools/go.mod create mode 100644 tools/go.sum diff --git a/WORKSPACE b/WORKSPACE index b52b8319a..f566d7d09 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -101,8 +101,8 @@ go_repository( go_repository( name = "dev_cel_expr", importpath = "cel.dev/expr", - sum = "h1:o+Wj235dy4gFYlYin3JsMpp3EEfMrPm/6tdoyjT98S0=", - version = "v0.21.2", + sum = "h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI=", + version = "v0.22.1", ) # local_repository( diff --git a/common/env/env.go b/common/env/env.go index 07294c696..8e57c42e3 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -836,7 +836,7 @@ func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) { } } -// SerializeTypeDesc converts *types.Type to a serialized format TypeDesc +// SerializeTypeDesc converts a CEL native *types.Type to a serializable TypeDesc. 
func SerializeTypeDesc(t *types.Type) *TypeDesc { typeName := t.TypeName() if t.Kind() == types.TypeParamKind { diff --git a/ext/extension_option_factory_test.go b/ext/extension_option_factory_test.go index f721bb6bf..573155baa 100644 --- a/ext/extension_option_factory_test.go +++ b/ext/extension_option_factory_test.go @@ -61,6 +61,9 @@ func TestExtensionOptionFactoryValidBindingsExtension(t *testing.T) { t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) } cfg, err := en.ToConfig("test config") + if err != nil { + t.Fatalf("ToConfig(%s) returned error: %v", e.Name, err) + } if len(cfg.Extensions) != 1 || cfg.Extensions[0].Name != "cel.lib.ext.cel.bindings" || cfg.Extensions[0].Version != "latest" { t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) } diff --git a/go.mod b/go.mod index 8bf321c5a..4108e1724 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21.1 toolchain go1.23.0 require ( - cel.dev/expr v0.21.2 + cel.dev/expr v0.22.1 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/stoewer/go-strcase v1.2.0 google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 diff --git a/go.sum b/go.sum index 69a046f2b..062b316c3 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.21.2 h1:o+Wj235dy4gFYlYin3JsMpp3EEfMrPm/6tdoyjT98S0= -cel.dev/expr v0.21.2/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= +cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index 875f523c3..f058f1ab9 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -15,7 +15,10 @@ 
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") package( - default_visibility = ["//policy:__subpackages__"], + default_visibility = [ + "//policy:__subpackages__", + "//tools:__subpackages__" + ], licenses = ["notice"], ) @@ -67,3 +70,8 @@ go_test( "@in_gopkg_yaml_v3//:go_default_library", ], ) + +filegroup( + name = "k8s_policy_testdata", + srcs = glob(["testdata/k8s/*"]), +) \ No newline at end of file diff --git a/policy/testdata/k8s/config.yaml b/policy/testdata/k8s/config.yaml index 15a32b535..5a2cb3290 100644 --- a/policy/testdata/k8s/config.yaml +++ b/policy/testdata/k8s/config.yaml @@ -14,6 +14,10 @@ name: k8s extensions: + - name: "optional" + version: "latest" + - name: "bindings" + version: "latest" - name: "strings" version: 2 variables: diff --git a/tools/compiler/BUILD.bazel b/tools/compiler/BUILD.bazel new file mode 100644 index 000000000..0c3e4080b --- /dev/null +++ b/tools/compiler/BUILD.bazel @@ -0,0 +1,75 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_library( + name = "go_default_library", + srcs = [ + "compiler.go", + ], + importpath = "github.com/google/cel-go/tools/compiler", + deps = [ + "//cel:go_default_library", + "//common:go_default_library", + "//common/env:go_default_library", + "//common/types:go_default_library", + "//ext:go_default_library", + "//policy:go_default_library", + "@dev_cel_expr//:expr", + "@dev_cel_expr//conformance:go_default_library", + "@in_gopkg_yaml_v3//:go_default_library", + "@io_bazel_rules_go//go/runfiles", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", + "@org_golang_google_protobuf//types/descriptorpb:go_default_library", + ], +) + +filegroup( + name = "compiler_testdata", + srcs = glob(["testdata/**"]), +) + +go_test( + name = "go_default_test", + size = "small", + srcs = [ + "compiler_test.go", + ], + data = [ + ":compiler_testdata", + "//policy:k8s_policy_testdata", + ], + embed = [":go_default_library"], + deps = [ + "//cel:go_default_library", + "//common/decls:go_default_library", + "//common/env:go_default_library", + "//common/types:go_default_library", + "//ext:go_default_library", + "//policy:go_default_library", + "@dev_cel_expr//:expr", + "@dev_cel_expr//conformance:go_default_library", + "@in_gopkg_yaml_v3//:go_default_library", + "@org_golang_google_protobuf//types/known/structpb:go_default_library", + ], +) diff --git a/tools/compiler/compiler.go b/tools/compiler/compiler.go new file mode 100644 index 000000000..fd590ecf3 --- /dev/null +++ b/tools/compiler/compiler.go @@ -0,0 +1,534 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 
2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package compiler exposes a standard way to set up a compiler which can be used for CEL +// expressions and policies. +package compiler + +import ( + "fmt" + "os" + "path/filepath" + + "gopkg.in/yaml.v3" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/env" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/ext" + "github.com/google/cel-go/policy" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + celpb "cel.dev/expr" + configpb "cel.dev/expr/conformance" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + descpb "google.golang.org/protobuf/types/descriptorpb" +) + +// FileFormat represents the format of the file being loaded. +type FileFormat int + +const ( + // Unspecified is used when the file format is not determined. + Unspecified FileFormat = iota + 1 + // BinaryProto is used for a binary proto file. + BinaryProto + // TextProto is used for a text proto file. + TextProto + // TextYAML is used for a YAML file. + TextYAML + // CELString is used for a CEL string expression defined in a file with .cel extension. + CELString + // CELPolicy is used for a CEL policy file with .celpolicy extension. + CELPolicy +) + +// PolicyMetadataEnvOption represents a function which accepts a policy metadata map and returns an +// environment option used to extend the CEL environment. 
+// +// The policy metadata map is generally produced as a byproduct of parsing the policy and it can +// be optionally customised by providing a custom policy parser. +type PolicyMetadataEnvOption func(map[string]any) cel.EnvOption + +// Compiler interface is used to set up a compiler with the following capabilities: +// - create a CEL environment +// - create a policy parser +// - fetch policy compiler options +// - fetch policy environment options +// +// Note: This compiler is not the same as the CEL expression compiler, rather it provides an +// abstraction layer which can create the different components needed for parsing and compiling CEL +// expressions and policies. +type Compiler interface { + // CreateEnv creates a singleton CEL environment with the configured environment options. + CreateEnv() (*cel.Env, error) + // CreatePolicyParser creates a policy parser using the optionally configured parser options. + CreatePolicyParser() (*policy.Parser, error) + // PolicyCompilerOptions returns the policy compiler options. + PolicyCompilerOptions() []policy.CompilerOption + // PolicyMetadataEnvOptions returns the policy metadata environment options. + PolicyMetadataEnvOptions() []PolicyMetadataEnvOption +} + +type compiler struct { + envOptions []cel.EnvOption + policyParserOptions []policy.ParserOption + policyCompilerOptions []policy.CompilerOption + policyMetadataEnvOptions []PolicyMetadataEnvOption + env *cel.Env +} + +// NewCompiler creates a new compiler with a set of functional options. 
+func NewCompiler(opts ...any) (Compiler, error) { + c := &compiler{ + envOptions: []cel.EnvOption{}, + policyParserOptions: []policy.ParserOption{}, + policyCompilerOptions: []policy.CompilerOption{}, + policyMetadataEnvOptions: []PolicyMetadataEnvOption{}, + } + for _, opt := range opts { + switch opt := opt.(type) { + case cel.EnvOption: + c.envOptions = append(c.envOptions, opt) + case policy.ParserOption: + c.policyParserOptions = append(c.policyParserOptions, opt) + case policy.CompilerOption: + c.policyCompilerOptions = append(c.policyCompilerOptions, opt) + case PolicyMetadataEnvOption: + c.policyMetadataEnvOptions = append(c.policyMetadataEnvOptions, opt) + default: + return nil, fmt.Errorf("unsupported compiler option: %v", opt) + } + } + return c, nil +} + +// CreateEnv creates a singleton CEL environment with the configured environment options. +func (c *compiler) CreateEnv() (*cel.Env, error) { + if c.env != nil { + return c.env, nil + } + env, err := cel.NewCustomEnv(c.envOptions...) + if err != nil { + return nil, err + } + c.env = env + return c.env, nil +} + +// CreatePolicyParser creates a policy parser using the optionally configured parser options. +func (c *compiler) CreatePolicyParser() (*policy.Parser, error) { + return policy.NewParser(c.policyParserOptions...) +} + +// PolicyCompilerOptions returns the policy compiler options configured in the compiler. +func (c *compiler) PolicyCompilerOptions() []policy.CompilerOption { + return c.policyCompilerOptions +} + +// PolicyMetadataEnvOptions returns the policy metadata environment options configured in the compiler. 
+func (c *compiler) PolicyMetadataEnvOptions() []PolicyMetadataEnvOption { + return c.policyMetadataEnvOptions +} + +func loadFile(path string) ([]byte, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read file %q: %v", path, err) + } + return data, err +} + +func loadProtoFile(path string, format FileFormat, out protoreflect.ProtoMessage) error { + unmarshaller := proto.Unmarshal + if format == TextProto { + unmarshaller = prototext.Unmarshal + } + data, err := loadFile(path) + if err != nil { + return err + } + return unmarshaller(data, out) +} + +func inferFileFormat(path string) FileFormat { + extension := filepath.Ext(path) + switch extension { + case ".textproto": + return TextProto + case ".yaml": + return TextYAML + case ".binarypb", ".fds": + return BinaryProto + case ".cel": + return CELString + case ".celpolicy": + return CELPolicy + default: + return Unspecified + } +} + +// EnvironmentFile returns an EnvOption which loads a serialized CEL environment from a file. 
+// The file must be in one of the following formats: +// - Textproto +// - YAML +// - Binarypb +func EnvironmentFile(path string) cel.EnvOption { + return func(e *cel.Env) (*cel.Env, error) { + format := inferFileFormat(path) + if format != TextProto && format != TextYAML && format != BinaryProto { + return nil, fmt.Errorf("file extension must be one of .textproto, .yaml, .binarypb: found %v", format) + } + var envConfig *env.Config + var fileDescriptorSet *descpb.FileDescriptorSet + switch format { + case TextProto, BinaryProto: + pbEnv := &configpb.Environment{} + var err error + if err = loadProtoFile(path, format, pbEnv); err != nil { + return nil, err + } + envConfig, fileDescriptorSet, err = envProtoToConfig(pbEnv) + if err != nil { + return nil, err + } + case TextYAML: + envConfig = &env.Config{} + data, err := loadFile(path) + if err != nil { + return nil, err + } + err = yaml.Unmarshal(data, envConfig) + if err != nil { + return nil, fmt.Errorf("yaml.Unmarshal failed to map CEL environment: %w", err) + } + default: + return nil, fmt.Errorf("unsupported format: %v, file: %s", format, path) + } + var envOpts []cel.EnvOption + if fileDescriptorSet != nil { + envOpts = append(envOpts, cel.TypeDescs(fileDescriptorSet)) + } + envOpts = append(envOpts, cel.FromConfig(envConfig, ext.ExtensionOptionFactory)) + var err error + e, err = e.Extend(envOpts...) 
+ if err != nil { + return nil, fmt.Errorf("e.Extend() with env options %v failed: %w", envOpts, err) + } + return e, nil + } +} + +func envProtoToConfig(pbEnv *configpb.Environment) (*env.Config, *descpb.FileDescriptorSet, error) { + if pbEnv == nil { + return nil, nil, fmt.Errorf("proto environment is not set") + } + envConfig := env.NewConfig(pbEnv.GetName()) + envConfig.Description = pbEnv.GetDescription() + envConfig.SetContainer(pbEnv.GetContainer()) + for _, imp := range pbEnv.GetImports() { + envConfig.AddImports(env.NewImport(imp.GetName())) + } + stdLib, err := envToStdLib(pbEnv) + if err != nil { + return nil, nil, err + } + envConfig.SetStdLib(stdLib) + extensions := make([]*env.Extension, 0, len(pbEnv.GetExtensions())) + for _, extension := range pbEnv.GetExtensions() { + extensions = append(extensions, &env.Extension{Name: extension.GetName(), Version: extension.GetVersion()}) + } + envConfig.AddExtensions(extensions...) + if contextVariable := pbEnv.GetContextVariable(); contextVariable != nil { + envConfig.SetContextVariable(env.NewContextVariable(contextVariable.GetTypeName())) + } + functions, variables, err := protoDeclToFunctionsAndVariables(pbEnv.GetDeclarations()) + if err != nil { + return nil, nil, err + } + envConfig.AddFunctions(functions...) + envConfig.AddVariables(variables...) + validators, err := envToValidators(pbEnv) + if err != nil { + return nil, nil, err + } + envConfig.AddValidators(validators...) + features, err := envToFeatures(pbEnv) + if err != nil { + return nil, nil, err + } + envConfig.AddFeatures(features...) 
+ fileDescriptorSet := pbEnv.GetMessageTypeExtension() + return envConfig, fileDescriptorSet, nil +} + +func envToFeatures(pbEnv *configpb.Environment) ([]*env.Feature, error) { + features := make([]*env.Feature, 0, len(pbEnv.GetFeatures())+1) + for _, feature := range pbEnv.GetFeatures() { + features = append(features, env.NewFeature(feature.GetName(), feature.GetEnabled())) + } + if pbEnv.GetEnableMacroCallTracking() { + features = append(features, env.NewFeature("cel.feature.macro_call_tracking", true)) + } + return features, nil +} + +func envToValidators(pbEnv *configpb.Environment) ([]*env.Validator, error) { + validators := make([]*env.Validator, 0, len(pbEnv.GetValidators())) + for _, pbValidator := range pbEnv.GetValidators() { + validator := env.NewValidator(pbValidator.GetName()) + config := map[string]any{} + for k, v := range pbValidator.GetConfig() { + val := types.DefaultTypeAdapter.NativeToValue(v) + config[k] = val + } + validator.SetConfig(config) + validators = append(validators, validator) + } + return validators, nil +} + +func protoDeclToFunctionsAndVariables(declarations []*celpb.Decl) ([]*env.Function, []*env.Variable, error) { + functions := make([]*env.Function, 0, len(declarations)) + variables := make([]*env.Variable, 0, len(declarations)) + for _, decl := range declarations { + switch decl.GetDeclKind().(type) { + case *celpb.Decl_Function: + fn, err := protoDeclToFunction(decl) + if err != nil { + return nil, nil, fmt.Errorf("protoDeclToFunction(%s) failed to create function: %w", decl.GetName(), err) + } + functions = append(functions, fn) + case *celpb.Decl_Ident: + t, err := types.ProtoAsType(decl.GetIdent().GetType()) + if err != nil { + return nil, nil, fmt.Errorf("types.ProtoAsType(%s) failed to create type: %w", decl.GetIdent().GetType(), err) + } + variables = append(variables, env.NewVariable(decl.GetName(), env.SerializeTypeDesc(t))) + } + } + return functions, variables, nil +} + +func envToStdLib(pbEnv 
*configpb.Environment) (*env.LibrarySubset, error) { + stdLib := env.NewLibrarySubset() + pbEnvStdLib := pbEnv.GetStdlib() + if pbEnvStdLib == nil { + if pbEnv.GetDisableStandardCelDeclarations() { + stdLib.SetDisabled(true) + return stdLib, nil + } + return nil, nil + } + if !stdLib.Disabled { + stdLib.SetDisabled(pbEnvStdLib.GetDisabled()) + } + stdLib.SetDisableMacros(pbEnvStdLib.GetDisableMacros()) + stdLib.AddIncludedMacros(pbEnvStdLib.GetIncludeMacros()...) + stdLib.AddExcludedMacros(pbEnvStdLib.GetExcludeMacros()...) + if pbEnvStdLib.GetIncludeFunctions() != nil { + for _, includeFn := range pbEnvStdLib.GetIncludeFunctions() { + if includeFn.GetFunction() != nil { + fn, err := protoDeclToFunction(includeFn) + if err != nil { + return nil, err + } + stdLib.AddIncludedFunctions(fn) + } else { + return nil, fmt.Errorf("IncludeFunctions must specify a function decl") + } + } + } + if pbEnvStdLib.GetExcludeFunctions() != nil { + for _, excludeFn := range pbEnvStdLib.GetExcludeFunctions() { + if excludeFn.GetFunction() != nil { + fn, err := protoDeclToFunction(excludeFn) + if err != nil { + return nil, err + } + stdLib.AddExcludedFunctions(fn) + } else { + return nil, fmt.Errorf("ExcludeFunctions must specify a function decl") + } + } + } + return stdLib, nil +} + +func protoDeclToFunction(decl *celpb.Decl) (*env.Function, error) { + declFn := decl.GetFunction() + if declFn == nil { + return nil, nil + } + overloads := make([]*env.Overload, 0, len(declFn.GetOverloads())) + for _, o := range declFn.GetOverloads() { + args := make([]*env.TypeDesc, 0, len(o.GetParams())) + for _, p := range o.GetParams() { + t, err := types.ProtoAsType(p) + if err != nil { + return nil, err + } + args = append(args, env.SerializeTypeDesc(t)) + } + res, err := types.ProtoAsType(o.GetResultType()) + if err != nil { + return nil, err + } + ret := env.SerializeTypeDesc(res) + if o.IsInstanceFunction { + overloads = append(overloads, env.NewMemberOverload(o.GetOverloadId(), args[0], 
args[1:], ret)) + } else { + overloads = append(overloads, env.NewOverload(o.GetOverloadId(), args, ret)) + } + } + return env.NewFunction(decl.GetName(), overloads...), nil +} + +// TypeDescriptorSetFile returns an EnvOption which loads type descriptors from a file. +// The file must be in binary format. +func TypeDescriptorSetFile(path string) cel.EnvOption { + return func(e *cel.Env) (*cel.Env, error) { + format := inferFileFormat(path) + if format != BinaryProto { + return nil, fmt.Errorf("type descriptor must be in binary format") + } + fds := &descpb.FileDescriptorSet{} + if err := loadProtoFile(path, BinaryProto, fds); err != nil { + return nil, err + } + var err error + e, err = e.Extend(cel.TypeDescs(fds)) + if err != nil { + return nil, fmt.Errorf("e.Extend() with type descriptor set %v failed: %w", fds, err) + } + return e, nil + } +} + +// InputExpression is an interface for an expression which can be compiled into a CEL AST and return +// an optional policy metadata map. +type InputExpression interface { + // CreateAST creates a CEL AST from the input expression using the provided compiler. + CreateAST(Compiler) (*cel.Ast, map[string]any, error) +} + +// CompiledExpression is an InputExpression which loads a CheckedExpr from a file. +type CompiledExpression struct { + Path string +} + +// CreateAST creates a CEL AST from a checked expression file. 
+// The file must be in one of the following formats: +// - Binarypb +// - Textproto +func (c *CompiledExpression) CreateAST(_ Compiler) (*cel.Ast, map[string]any, error) { + var expr exprpb.CheckedExpr + format := inferFileFormat(c.Path) + if format != BinaryProto && format != TextProto { + return nil, nil, fmt.Errorf("file extension must be .binarypb or .textproto: found %v", format) + } + if err := loadProtoFile(c.Path, format, &expr); err != nil { + return nil, nil, err + } + return cel.CheckedExprToAst(&expr), nil, nil +} + +// FileExpression is an InputExpression which loads a CEL expression or policy from a file. +type FileExpression struct { + Path string +} + +// CreateAST creates a CEL AST from a file using the provided compiler: +// - All policy metadata options as executed using the policy metadata map to extend the +// environment. +// - All policy compiler options are passed on to compile the parsed policy. +// +// The file must be in one of the following formats: +// - .cel: CEL string expression +// - .celpolicy: CEL policy +func (f *FileExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, error) { + e, err := compiler.CreateEnv() + if err != nil { + return nil, nil, err + } + data, err := loadFile(f.Path) + if err != nil { + return nil, nil, err + } + format := inferFileFormat(f.Path) + switch format { + case CELString: + src := common.NewStringSource(string(data), f.Path) + ast, iss := e.CompileSource(src) + if iss.Err() != nil { + return nil, nil, fmt.Errorf("e.CompileSource(%q) failed: %w", src.Content(), iss.Err()) + } + return ast, nil, nil + case CELPolicy, TextYAML: + src := policy.ByteSource(data, f.Path) + parser, err := compiler.CreatePolicyParser() + if err != nil { + return nil, nil, err + } + p, iss := parser.Parse(src) + if iss.Err() != nil { + return nil, nil, fmt.Errorf("parser.Parse(%q) failed: %w", src.Content(), iss.Err()) + } + policyMetadata := clonePolicyMetadata(p) + for _, opt := range 
compiler.PolicyMetadataEnvOptions() { + if e, err = e.Extend(opt(policyMetadata)); err != nil { + return nil, nil, fmt.Errorf("e.Extend() with metadata option failed: %w", err) + } + } + ast, iss := policy.Compile(e, p, compiler.PolicyCompilerOptions()...) + if iss.Err() != nil { + return nil, nil, fmt.Errorf("policy.Compile(%q) failed: %w", src.Content(), iss.Err()) + } + return ast, policyMetadata, nil + default: + return nil, nil, fmt.Errorf("unsupported file format: %v", format) + } +} + +func clonePolicyMetadata(p *policy.Policy) map[string]any { + metadataKeys := p.MetadataKeys() + metadata := make(map[string]any, len(metadataKeys)) + for _, key := range metadataKeys { + value, _ := p.Metadata(key) + metadata[key] = value + } + return metadata +} + +// RawExpression is an InputExpression which loads a CEL expression from a string. +type RawExpression struct { + Value string +} + +// CreateAST creates a CEL AST from a raw CEL expression using the provided compiler. +func (r *RawExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, error) { + e, err := compiler.CreateEnv() + if err != nil { + return nil, nil, err + } + ast, iss := e.Compile(r.Value) + if iss.Err() != nil { + return nil, nil, fmt.Errorf("e.Compile(%q) failed: %w", r.Value, iss.Err()) + } + return ast, nil, nil +} diff --git a/tools/compiler/compiler_test.go b/tools/compiler/compiler_test.go new file mode 100644 index 000000000..d0c9ad0be --- /dev/null +++ b/tools/compiler/compiler_test.go @@ -0,0 +1,492 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compiler + +import ( + "reflect" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/env" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/ext" + "github.com/google/cel-go/policy" + "gopkg.in/yaml.v3" + + celpb "cel.dev/expr" + configpb "cel.dev/expr/conformance" + structpb "google.golang.org/protobuf/types/known/structpb" +) + +func TestEnvironmentFileCompareTextprotoAndYAML(t *testing.T) { + t.Run("compare textproto and yaml environment files", func(t *testing.T) { + pbEnv := testEnvProto() + protoConfig, err := configFromEnvProto(t, pbEnv) + if err != nil { + t.Fatalf("configFromEnvProto(%v) failed: %v", pbEnv, err) + } + config, err := parseEnv(t, "yaml_config", "testdata/config.yaml") + if err != nil { + t.Fatalf("parseEnv(%s) failed: %v", "testdata/config.yaml", err) + } + if protoConfig.Container != config.Container { + t.Fatalf("Container got %q, wanted %q", protoConfig.Container, config.Container) + } + if !reflect.DeepEqual(protoConfig.Imports, config.Imports) { + t.Fatalf("Imports got %v, wanted %v", protoConfig.Imports, config.Imports) + } + if !reflect.DeepEqual(protoConfig.StdLib, config.StdLib) { + t.Fatalf("StdLib got %v, wanted %v", protoConfig.StdLib, config.StdLib) + } + if len(protoConfig.Extensions) != len(config.Extensions) { + t.Fatalf("Extensions count got %d, wanted %d", len(protoConfig.Extensions), len(config.Extensions)) + } + for _, protoConfigExt := range protoConfig.Extensions { + found := false + for _, configExt 
:= range config.Extensions { + if reflect.DeepEqual(protoConfigExt, configExt) { + found = true + break + } + } + if !found { + t.Fatalf("Extensions got %v, wanted %v", protoConfig.Extensions, config.Extensions) + } + } + if !reflect.DeepEqual(protoConfig.ContextVariable, config.ContextVariable) { + t.Fatalf("ContextVariable got %v, wanted %v", protoConfig.ContextVariable, config.ContextVariable) + } + if len(protoConfig.Variables) != len(config.Variables) { + t.Fatalf("Variables count got %d, wanted %d", len(protoConfig.Variables), len(config.Variables)) + } else { + for i, v := range protoConfig.Variables { + for j, p := range v.TypeDesc.Params { + if p.TypeName == "google.protobuf.Any" && + config.Variables[i].TypeDesc.Params[j].TypeName == "dyn" { + p.TypeName = "dyn" + } + } + if !reflect.DeepEqual(v, config.Variables[i]) { + t.Fatalf("Variables[%d] not equal, got %v, wanted %v", i, v, config.Variables[i]) + } + } + } + if len(protoConfig.Functions) != len(config.Functions) { + t.Fatalf("Functions count got %d, wanted %d", len(protoConfig.Functions), len(config.Functions)) + } else { + for i, f := range protoConfig.Functions { + if !reflect.DeepEqual(f, config.Functions[i]) { + t.Fatalf("Functions[%d] not equal, got %v, wanted %v", i, f, config.Functions[i]) + } + } + } + if len(protoConfig.Features) != len(config.Features) { + t.Fatalf("Features count got %d, wanted %d", len(protoConfig.Features), len(config.Features)) + } else { + for i, f := range protoConfig.Features { + if !reflect.DeepEqual(f, config.Features[i]) { + t.Fatalf("Features[%d] not equal, got %v, wanted %v", i, f, config.Features[i]) + } + } + } + if len(protoConfig.Validators) != len(config.Validators) { + t.Fatalf("Validators count got %d, wanted %d", len(protoConfig.Validators), len(config.Validators)) + } else { + for i, v := range protoConfig.Validators { + if !reflect.DeepEqual(v, config.Validators[i]) { + t.Fatalf("Validators[%d] not equal, got %v, wanted %v", i, v, 
config.Validators[i]) + } + } + } + }) +} + +func testEnvProto() *configpb.Environment { + return &configpb.Environment{ + Name: "test-environment", + Description: "Test environment", + Container: "google.expr", + Imports: []*configpb.Environment_Import{ + {Name: "google.expr.proto3.test.TestAllTypes"}, + }, + Stdlib: &configpb.LibrarySubset{ + IncludeMacros: []string{"has", "exists"}, + IncludeFunctions: []*celpb.Decl{ + { + Name: "_==_", + DeclKind: &celpb.Decl_Function{ + Function: &celpb.Decl_FunctionDecl{ + Overloads: []*celpb.Decl_FunctionDecl_Overload{ + { + OverloadId: "equals", + Params: []*celpb.Type{ + { + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + { + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + }, + ResultType: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_BOOL, + }, + }, + }, + }, + }, + }, + }, + { + Name: "_||_", + DeclKind: &celpb.Decl_Function{ + Function: &celpb.Decl_FunctionDecl{ + Overloads: []*celpb.Decl_FunctionDecl_Overload{ + { + OverloadId: "logical_or", + Params: []*celpb.Type{ + { + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_BOOL, + }, + }, + { + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_BOOL, + }, + }, + }, + ResultType: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_BOOL, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Extensions: []*configpb.Extension{ + { + Name: "optional", + Version: "latest", + }, + { + Name: "lists", + Version: "latest", + }, + { + Name: "sets", + Version: "latest", + }, + }, + Declarations: []*celpb.Decl{ + { + Name: "destination.ip", + DeclKind: &celpb.Decl_Ident{ + Ident: &celpb.Decl_IdentDecl{ + Type: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + }, + }, + }, + { + Name: "origin.ip", + DeclKind: &celpb.Decl_Ident{ + Ident: &celpb.Decl_IdentDecl{ + Type: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: 
celpb.Type_STRING, + }, + }, + }, + }, + }, + { + Name: "spec.restricted_destinations", + DeclKind: &celpb.Decl_Ident{ + Ident: &celpb.Decl_IdentDecl{ + Type: &celpb.Type{ + TypeKind: &celpb.Type_ListType_{ + ListType: &celpb.Type_ListType{ + ElemType: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + }, + }, + }, + }, + }, + }, + { + Name: "spec.origin", + DeclKind: &celpb.Decl_Ident{ + Ident: &celpb.Decl_IdentDecl{ + Type: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + }, + }, + }, + { + Name: "request", + DeclKind: &celpb.Decl_Ident{ + Ident: &celpb.Decl_IdentDecl{ + Type: &celpb.Type{ + TypeKind: &celpb.Type_MapType_{ + MapType: &celpb.Type_MapType{ + KeyType: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + ValueType: &celpb.Type{ + TypeKind: &celpb.Type_WellKnown{ + WellKnown: celpb.Type_ANY, + }, + }, + }, + }, + }, + }, + }, + }, + { + Name: "resource", + DeclKind: &celpb.Decl_Ident{ + Ident: &celpb.Decl_IdentDecl{ + Type: &celpb.Type{ + TypeKind: &celpb.Type_MapType_{ + MapType: &celpb.Type_MapType{ + KeyType: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + ValueType: &celpb.Type{ + TypeKind: &celpb.Type_WellKnown{ + WellKnown: celpb.Type_ANY, + }, + }, + }, + }, + }, + }, + }, + }, + { + Name: "locationCode", + DeclKind: &celpb.Decl_Function{ + Function: &celpb.Decl_FunctionDecl{ + Overloads: []*celpb.Decl_FunctionDecl_Overload{ + { + OverloadId: "locationCode_string", + Params: []*celpb.Type{ + { + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + }, + ResultType: &celpb.Type{ + TypeKind: &celpb.Type_Primitive{ + Primitive: celpb.Type_STRING, + }, + }, + }, + }, + }, + }, + }, + }, + Validators: []*configpb.Validator{ + {Name: "cel.validator.duration"}, + { + Name: "cel.validator.nesting_comprehension_limit", + Config: map[string]*structpb.Value{ + 
"limits": structpb.NewNumberValue(2), + }, + }, + }, + Features: []*configpb.Feature{ + { + Name: "cel.feature.macro_call_tracking", + Enabled: true, + }, + }, + } +} + +func configFromEnvProto(t *testing.T, pbEnv *configpb.Environment) (*env.Config, error) { + t.Helper() + envConfig, fileDescriptorSet, err := envProtoToConfig(pbEnv) + if err != nil { + return nil, err + } + var envOpts []cel.EnvOption + if fileDescriptorSet != nil { + envOpts = append(envOpts, cel.TypeDescs(fileDescriptorSet)) + } + envOpts = append(envOpts, cel.FromConfig(envConfig, ext.ExtensionOptionFactory)) + return envOptionToConfig(t, envConfig.Name, envOpts...) +} + +func parseEnv(t *testing.T, name, path string) (*env.Config, error) { + t.Helper() + opts := EnvironmentFile(path) + return envOptionToConfig(t, name, opts) +} + +func envOptionToConfig(t *testing.T, name string, opts ...cel.EnvOption) (*env.Config, error) { + t.Helper() + e, err := cel.NewCustomEnv(opts...) + if err != nil { + return nil, err + } + conf, err := e.ToConfig(name) + if err != nil { + return nil, err + } + return conf, nil +} + +func TestFileExpressionCustomPolicyParser(t *testing.T) { + t.Run("test file expression custom policy parser", func(t *testing.T) { + envOpt := EnvironmentFile("../../policy/testdata/k8s/config.yaml") + parserOpt := policy.ParserOption(func(p *policy.Parser) (*policy.Parser, error) { + p.TagVisitor = policy.K8sTestTagHandler() + return p, nil + }) + compilerOpts := []any{envOpt, parserOpt} + compiler, err := NewCompiler(compilerOpts...) 
+ if err != nil { + t.Fatalf("NewCompiler() failed: %v", err) + } + policyFile := &FileExpression{ + Path: "../../policy/testdata/k8s/policy.yaml", + } + k8sAst, _, err := policyFile.CreateAST(compiler) + if err != nil { + t.Fatalf("CreateAST() failed: %v", err) + } + if k8sAst == nil { + t.Fatalf("CreateAST() returned nil ast") + } + }) +} + +func TestFileExpressionPolicyMetadataOptions(t *testing.T) { + t.Run("test file expression policy metadata options", func(t *testing.T) { + envOpt := EnvironmentFile("testdata/custom_policy_config.yaml") + parserOpt := policy.ParserOption(func(p *policy.Parser) (*policy.Parser, error) { + p.TagVisitor = customTagHandler{TagVisitor: policy.DefaultTagVisitor()} + return p, nil + }) + policyMetadataOpt := PolicyMetadataEnvOption(ParsePolicyVariables) + compilerOpts := []any{envOpt, parserOpt, policyMetadataOpt} + compiler, err := NewCompiler(compilerOpts...) + if err != nil { + t.Fatalf("NewCompiler() failed: %v", err) + } + policyFile := &FileExpression{ + Path: "testdata/custom_policy.celpolicy", + } + ast, _, err := policyFile.CreateAST(compiler) + if err != nil { + t.Fatalf("CreateAST() failed: %v", err) + } + if ast == nil { + t.Fatalf("CreateAST() returned nil ast") + } + }) +} + +func ParsePolicyVariables(metadata map[string]any) cel.EnvOption { + variables := []*decls.VariableDecl{} + for n, t := range metadata { + variables = append(variables, decls.NewVariable(n, parseCustomPolicyVariableType(t.(string)))) + } + return cel.VariableDecls(variables...) 
+} + +func parseCustomPolicyVariableType(t string) *types.Type { + switch t { + case "int": + return types.IntType + case "string": + return types.StringType + default: + return types.UnknownType + } +} + +type variableType struct { + VariableName string `yaml:"variable_name"` + VariableType string `yaml:"variable_type"` +} + +type customTagHandler struct { + policy.TagVisitor +} + +func (customTagHandler) PolicyTag(ctx policy.ParserContext, id int64, tagName string, node *yaml.Node, p *policy.Policy) { + switch tagName { + case "variable_types": + varList := []*variableType{} + if err := node.Decode(&varList); err != nil { + ctx.ReportErrorAtID(id, "invalid yaml variable_types node: %v, error: %w", node, err) + return + } + for _, v := range varList { + p.SetMetadata(v.VariableName, v.VariableType) + } + default: + ctx.ReportErrorAtID(id, "unsupported policy tag: %s", tagName) + } +} + +func TestRawExpressionCreateAst(t *testing.T) { + t.Run("test raw expression create ast", func(t *testing.T) { + envOpt := EnvironmentFile("testdata/config.yaml") + compiler, err := NewCompiler(envOpt) + if err != nil { + t.Fatalf("NewCompiler() failed: %v", err) + } + rawExpr := &RawExpression{ + Value: "locationCode(destination.ip)==locationCode(origin.ip)", + } + ast, _, err := rawExpr.CreateAST(compiler) + if err != nil { + t.Fatalf("CreateAST() failed: %v", err) + } + if ast == nil { + t.Fatalf("CreateAST() returned nil ast") + } + }) +} diff --git a/tools/compiler/testdata/config.yaml b/tools/compiler/testdata/config.yaml new file mode 100644 index 000000000..a63a859ad --- /dev/null +++ b/tools/compiler/testdata/config.yaml @@ -0,0 +1,84 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "test-environment" +description: "Test environment" +container: "google.expr" +imports: +- name: "google.expr.proto3.test.TestAllTypes" +stdlib: + include_macros: + - has + - exists + include_functions: + - name: "_==_" + overloads: + - id: equals + args: + - type_name: "string" + - type_name: "string" + return: + type_name: "bool" + - name: "_||_" + overloads: + - id: logical_or + args: + - type_name: "bool" + - type_name: "bool" + return: + type_name: "bool" +extensions: +- name: "optional" + version: "latest" +- name: "lists" + version: "latest" +- name: "sets" + version: "latest" +variables: +- name: "destination.ip" + type_name: "string" +- name: "origin.ip" + type_name: "string" +- name: "spec.restricted_destinations" + type_name: "list" + params: + - type_name: "string" +- name: "spec.origin" + type_name: "string" +- name: "request" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" +- name: "resource" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" +functions: +- name: "locationCode" + overloads: + - id: "locationCode_string" + args: + - type_name: "string" + return: + type_name: "string" +validators: +- name: cel.validator.duration +- name: cel.validator.nesting_comprehension_limit + config: + limit: 2 +features: +- name: cel.feature.macro_call_tracking + enabled: true diff --git a/tools/compiler/testdata/custom_policy.celpolicy b/tools/compiler/testdata/custom_policy.celpolicy new file mode 100644 index 000000000..663fcf0a7 --- /dev/null +++ 
b/tools/compiler/testdata/custom_policy.celpolicy @@ -0,0 +1,25 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "custom_policy" +variable_types: + - variable_name: "variable1" + variable_type: "int" + - variable_name: "variable2" + variable_type: "string" +rule: + match: + - condition: | + variable1 == 1 || variable2 == "known" + output: "true" diff --git a/tools/compiler/testdata/custom_policy_config.yaml b/tools/compiler/testdata/custom_policy_config.yaml new file mode 100644 index 000000000..460fab4b4 --- /dev/null +++ b/tools/compiler/testdata/custom_policy_config.yaml @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: "custom_policy_config" +extensions: +- name: "optional" + version: "latest" diff --git a/tools/go.mod b/tools/go.mod new file mode 100644 index 000000000..392efac8f --- /dev/null +++ b/tools/go.mod @@ -0,0 +1,22 @@ +module github.com/google/cel-go/tools + +go 1.23.0 + +require ( + cel.dev/expr v0.22.1 + github.com/google/cel-go v0.22.0 + github.com/google/cel-go/policy v0.0.0-20250311174852-f5ea07b389a1 + google.golang.org/genproto/googleapis/api v0.0.0-20250311190419-81fb87f6b8bf + google.golang.org/protobuf v1.36.5 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/antlr4-go/antlr/v4 v4.13.1 // indirect + github.com/stoewer/go-strcase v1.3.0 // indirect + golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect + golang.org/x/text v0.22.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250311190419-81fb87f6b8bf // indirect +) + +replace github.com/google/cel-go => ../. diff --git a/tools/go.sum b/tools/go.sum new file mode 100644 index 000000000..b34becfc8 --- /dev/null +++ b/tools/go.sum @@ -0,0 +1,37 @@ +cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= +cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= +github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/cel-go/policy v0.0.0-20250311174852-f5ea07b389a1 h1:jT/04RYwo++S9tvHggXWuAqvnc2Pi0BTHYsZYVOoMOs= +github.com/google/cel-go/policy v0.0.0-20250311174852-f5ea07b389a1/go.mod h1:dgvqy3CzFx17CBMkL0s1hd0r1+rEQOo85tDpr0g6Dp4= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp 
v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stoewer/go-strcase v1.3.0 h1:g0eASXYtp+yvN9fK8sH94oCIk0fau9uV1/ZdJ0AVEzs= +github.com/stoewer/go-strcase v1.3.0/go.mod h1:fAH5hQ5pehh+j3nZfvwdk2RgEgQjAoM8wodgtPmh1xo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +google.golang.org/genproto/googleapis/api v0.0.0-20250311190419-81fb87f6b8bf h1:BdIVRm+fyDUn8lrZLPSlBCfM/YKDwUBYgDoLv9+DYo0= +google.golang.org/genproto/googleapis/api v0.0.0-20250311190419-81fb87f6b8bf/go.mod h1:jbe3Bkdp+Dh2IrslsFCklNhweNTBgSYanP1UXhJDhKg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250311190419-81fb87f6b8bf h1:dHDlF3CWxQkefK9IJx+O8ldY0gLygvrlYRBNbPqDWuY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250311190419-81fb87f6b8bf/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I= +google.golang.org/protobuf v1.36.5 
h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= +google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/cel.dev/expr/.bazelversion b/vendor/cel.dev/expr/.bazelversion index 26bc914a3..13c50892b 100644 --- a/vendor/cel.dev/expr/.bazelversion +++ b/vendor/cel.dev/expr/.bazelversion @@ -1,2 +1,2 @@ -7.0.1 +7.3.2 # Keep this pinned version in parity with cel-go diff --git a/vendor/cel.dev/expr/MODULE.bazel b/vendor/cel.dev/expr/MODULE.bazel index 9794266f5..c0a631316 100644 --- a/vendor/cel.dev/expr/MODULE.bazel +++ b/vendor/cel.dev/expr/MODULE.bazel @@ -13,12 +13,24 @@ bazel_dep( ) bazel_dep( name = "googleapis", - version = "0.0.0-20240819-fe8ba054a", + version = "0.0.0-20241220-5e258e33.bcr.1", repo_name = "com_google_googleapis", ) +bazel_dep( + name = "googleapis-cc", + version = "1.0.0", +) +bazel_dep( + name = "googleapis-java", + version = "1.0.0", +) +bazel_dep( + name = "googleapis-go", + version = "1.0.0", +) bazel_dep( name = "protobuf", - version = "26.0", + version = "27.0", repo_name = "com_google_protobuf", ) bazel_dep( @@ -27,7 +39,7 @@ bazel_dep( ) bazel_dep( name = "rules_go", - version = "0.49.0", + version = "0.50.1", repo_name = "io_bazel_rules_go", ) bazel_dep( @@ -50,14 +62,6 @@ python.toolchain( python_version = "3.11", ) -switched_rules = use_extension("@com_google_googleapis//:extensions.bzl", "switched_rules") -switched_rules.use_languages( - cc = True, - go = True, - java = True, -) -use_repo(switched_rules, 
"com_google_googleapis_imports") - go_sdk = use_extension("@io_bazel_rules_go//go:extensions.bzl", "go_sdk") go_sdk.download(version = "1.21.1") diff --git a/vendor/cel.dev/expr/cloudbuild.yaml b/vendor/cel.dev/expr/cloudbuild.yaml index c40881f12..e3e533a04 100644 --- a/vendor/cel.dev/expr/cloudbuild.yaml +++ b/vendor/cel.dev/expr/cloudbuild.yaml @@ -1,5 +1,5 @@ steps: -- name: 'gcr.io/cloud-builders/bazel:7.0.1' +- name: 'gcr.io/cloud-builders/bazel:7.3.2' entrypoint: bazel args: ['build', '...'] id: bazel-build diff --git a/vendor/modules.txt b/vendor/modules.txt index 620a2e524..dfdf1bd13 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# cel.dev/expr v0.21.2 +# cel.dev/expr v0.22.1 ## explicit; go 1.21.1 cel.dev/expr # github.com/antlr4-go/antlr/v4 v4.13.0 From f6a27f78e87480744b0f363c8010be5f3fda048a Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 26 Mar 2025 18:31:28 -0700 Subject: [PATCH 20/46] Bzlmod configuration fixes (#1146) * Update bzlmod deps to fix conformance breakage * Update proto rules version * Update go build deps * Update the go toolchain version in GCB --- .bazelversion | 2 +- MODULE.bazel | 21 ++++++++++----------- cloudbuild.yaml | 2 +- codelab/go.mod | 5 +++-- codelab/go.sum | 10 ++++++---- conformance/BUILD.bazel | 1 + conformance/go.mod | 5 +++-- conformance/go.sum | 6 ++++-- go.mod | 2 +- policy/go.mod | 6 +++--- policy/go.sum | 8 ++++---- repl/appengine/go.mod | 2 +- repl/go.mod | 4 ++-- 13 files changed, 40 insertions(+), 34 deletions(-) diff --git a/.bazelversion b/.bazelversion index 9fe9ff9d9..eab246c06 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -7.0.1 +7.3.2 diff --git a/MODULE.bazel b/MODULE.bazel index 54f1cac34..50096d0ea 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -13,32 +13,30 @@ bazel_dep( ) bazel_dep( name = "googleapis", - version = "0.0.0-20240819-fe8ba054a", + version = "0.0.0-20241220-5e258e33.bcr.1", repo_name = "com_google_googleapis", ) +bazel_dep( + name = 
"googleapis-go", + version = "1.0.0", +) bazel_dep( name = "protobuf", - version = "26.0", + version = "27.0", repo_name = "com_google_protobuf", ) bazel_dep( name = "rules_go", - version = "0.50.1", + version = "0.53.0", repo_name = "io_bazel_rules_go", ) bazel_dep( name = "rules_proto", - version = "6.0.0", -) - -switched_rules = use_extension("@com_google_googleapis//:extensions.bzl", "switched_rules") -switched_rules.use_languages( - go = True, + version = "7.0.2", ) -use_repo(switched_rules, "com_google_googleapis_imports") go_sdk = use_extension("@io_bazel_rules_go//go:extensions.bzl", "go_sdk") -go_sdk.download(version = "1.21.1") +go_sdk.download(version = "1.22.0") go_deps = use_extension("@bazel_gazelle//:extensions.bzl", "go_deps") go_deps.gazelle_default_attributes( @@ -53,6 +51,7 @@ go_deps.gazelle_override( "gazelle:go_generate_proto false", # Provide hints to gazelle about how includes and imports map to build targets "gazelle:resolve go cel.dev/expr @dev_cel_expr//:expr", + "gazelle:resolve go cel.dev/expr/conformance @dev_cel_expr//conformance:go_default_library", "gazelle:resolve proto go google/rpc/status.proto @org_golang_google_genproto_googleapis_rpc//status", "gazelle:resolve proto proto google/rpc/status.proto @googleapis//google/rpc:status_proto", ], diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 4fd87abf9..32523f4ff 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -21,7 +21,7 @@ steps: # entrypoint: sh # # deploys folder of test results (with build id as folder name) to GCS # args: ['-c', 'gsutil cp -r $(cat _DATE) gs://cel-conformance/test-logs/'] - - name: 'golang:1.21.1' + - name: 'golang:1.22.0' # check the integrity of the vendor directory args: ['scripts/verify-vendor.sh'] - name: 'gcr.io/cloud-builders/bazel' diff --git a/codelab/go.mod b/codelab/go.mod index fb770bf22..58d84bb8b 100644 --- a/codelab/go.mod +++ b/codelab/go.mod @@ -1,6 +1,7 @@ module github.com/google/cel-go/codelab -go 1.21 +go 1.22.0 + toolchain 
go1.22.5 require ( @@ -11,7 +12,7 @@ require ( ) require ( - cel.dev/expr v0.19.1 // indirect + cel.dev/expr v0.22.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect diff --git a/codelab/go.sum b/codelab/go.sum index f956c29d3..8dbce6c33 100644 --- a/codelab/go.sum +++ b/codelab/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= -cel.dev/expr v0.19.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= +cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= @@ -17,8 +17,8 @@ github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 h1:YcyjlL1PRr2Q17/I0dPk2JmYS5CDXfcdb2Z3YRioEbw= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= google.golang.org/genproto/googleapis/rpc 
v0.0.0-20240826202546-f6391c0de4c7 h1:2035KHhUv+EpyB+hWgJnaWKJOdX1E95w2S8Rr4uWKTs= @@ -28,3 +28,5 @@ google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWn gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/conformance/BUILD.bazel b/conformance/BUILD.bazel index 6b2e50bcf..50ec9c8d3 100644 --- a/conformance/BUILD.bazel +++ b/conformance/BUILD.bazel @@ -83,6 +83,7 @@ go_test( "//ext:go_default_library", "@com_github_google_go_cmp//cmp:go_default_library", "@dev_cel_expr//:expr", + "@dev_cel_expr//conformance:go_default_library", "@dev_cel_expr//conformance/test:go_default_library", "@dev_cel_expr//conformance/proto2:go_default_library", "@dev_cel_expr//conformance/proto3:go_default_library", diff --git a/conformance/go.mod b/conformance/go.mod index b666bfe67..115630be9 100644 --- a/conformance/go.mod +++ b/conformance/go.mod @@ -1,9 +1,9 @@ module github.com/google/cel-go/conformance -go 1.21.1 +go 1.22.0 require ( - cel.dev/expr v0.21.2 + cel.dev/expr v0.22.1 github.com/bazelbuild/rules_go v0.49.0 github.com/google/cel-go v0.21.0 github.com/google/go-cmp v0.6.0 @@ -14,6 +14,7 @@ require ( github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect + golang.org/x/text v0.22.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 // indirect ) diff --git a/conformance/go.sum b/conformance/go.sum index cadf61fd9..9544b17a9 100644 --- 
a/conformance/go.sum +++ b/conformance/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.21.2 h1:o+Wj235dy4gFYlYin3JsMpp3EEfMrPm/6tdoyjT98S0= -cel.dev/expr v0.21.2/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= +cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/bazelbuild/rules_go v0.49.0 h1:5vCbuvy8Q11g41lseGJDc5vxhDjJtfxr6nM/IC4VmqM= @@ -17,6 +17,8 @@ github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 h1:YcyjlL1PRr2Q17/I0dPk2JmYS5CDXfcdb2Z3YRioEbw= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 h1:2035KHhUv+EpyB+hWgJnaWKJOdX1E95w2S8Rr4uWKTs= diff --git a/go.mod b/go.mod index 4108e1724..9f089f4fd 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/google/cel-go -go 1.21.1 +go 1.22.0 toolchain go1.23.0 diff --git a/policy/go.mod b/policy/go.mod index 6c996e61a..410781f5d 100644 --- a/policy/go.mod +++ b/policy/go.mod @@ -1,6 +1,6 @@ module github.com/google/cel-go/policy -go 1.22 +go 1.22.0 require ( github.com/google/cel-go v0.22.0 @@ -9,11 +9,11 @@ require ( ) 
require ( - cel.dev/expr v0.19.1 // indirect + cel.dev/expr v0.22.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/text v0.22.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 // indirect ) diff --git a/policy/go.sum b/policy/go.sum index a0f0bfd91..8b4ac4221 100644 --- a/policy/go.sum +++ b/policy/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= -cel.dev/expr v0.19.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= +cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -20,8 +20,8 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 
h1:YcyjlL1PRr2Q17/I0dPk2JmYS5CDXfcdb2Z3YRioEbw= google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 h1:2035KHhUv+EpyB+hWgJnaWKJOdX1E95w2S8Rr4uWKTs= diff --git a/repl/appengine/go.mod b/repl/appengine/go.mod index a130217ba..1bb0dc512 100644 --- a/repl/appengine/go.mod +++ b/repl/appengine/go.mod @@ -1,6 +1,6 @@ module github.com/google/cel-go/repl/appengine -go 1.21 +go 1.22.0 require github.com/google/cel-go/repl v0.0.0-20230406155237-b081aea03865 diff --git a/repl/go.mod b/repl/go.mod index 12e764df3..2ac63e68a 100644 --- a/repl/go.mod +++ b/repl/go.mod @@ -1,9 +1,9 @@ module github.com/google/cel-go/repl -go 1.21.1 +go 1.22.0 require ( - cel.dev/expr v0.21.2 + cel.dev/expr v0.22.1 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/chzyer/readline v1.5.1 github.com/google/cel-go v0.0.0-00010101000000-000000000000 From e31356b4078a62abf5888033d57a5804fd9bd8fa Mon Sep 17 00:00:00 2001 From: "zev.ac" Date: Thu, 27 Mar 2025 09:34:36 +0530 Subject: [PATCH 21/46] Refactoring changes (#1145) --- policy/testdata/context_pb/config.yaml | 2 + tools/compiler/compiler.go | 2 +- tools/compiler/testdata/config.yaml | 116 +++++++++--------- .../testdata/custom_policy_config.yaml | 4 +- 4 files changed, 63 insertions(+), 61 deletions(-) diff --git a/policy/testdata/context_pb/config.yaml b/policy/testdata/context_pb/config.yaml index 53ea95425..e7804575c 100644 --- a/policy/testdata/context_pb/config.yaml +++ b/policy/testdata/context_pb/config.yaml @@ -15,6 +15,8 @@ name: "context_pb" container: "google.expr.proto3" extensions: + - name: "optional" + version: "latest" - name: "strings" version: 2 context_variable: diff --git a/tools/compiler/compiler.go b/tools/compiler/compiler.go index fd590ecf3..c2263e02f 100644 --- a/tools/compiler/compiler.go +++ b/tools/compiler/compiler.go @@ -454,7 +454,7 @@ type FileExpression struct 
{ } // CreateAST creates a CEL AST from a file using the provided compiler: -// - All policy metadata options as executed using the policy metadata map to extend the +// - All policy metadata options are executed using the policy metadata map to extend the // environment. // - All policy compiler options are passed on to compile the parsed policy. // diff --git a/tools/compiler/testdata/config.yaml b/tools/compiler/testdata/config.yaml index a63a859ad..929427bc0 100644 --- a/tools/compiler/testdata/config.yaml +++ b/tools/compiler/testdata/config.yaml @@ -16,69 +16,69 @@ name: "test-environment" description: "Test environment" container: "google.expr" imports: -- name: "google.expr.proto3.test.TestAllTypes" + - name: "google.expr.proto3.test.TestAllTypes" stdlib: include_macros: - - has - - exists + - has + - exists include_functions: - - name: "_==_" - overloads: - - id: equals - args: - - type_name: "string" - - type_name: "string" - return: - type_name: "bool" - - name: "_||_" - overloads: - - id: logical_or - args: - - type_name: "bool" - - type_name: "bool" - return: - type_name: "bool" + - name: "_==_" + overloads: + - id: equals + args: + - type_name: "string" + - type_name: "string" + return: + type_name: "bool" + - name: "_||_" + overloads: + - id: logical_or + args: + - type_name: "bool" + - type_name: "bool" + return: + type_name: "bool" extensions: -- name: "optional" - version: "latest" -- name: "lists" - version: "latest" -- name: "sets" - version: "latest" + - name: "optional" + version: "latest" + - name: "lists" + version: "latest" + - name: "sets" + version: "latest" variables: -- name: "destination.ip" - type_name: "string" -- name: "origin.ip" - type_name: "string" -- name: "spec.restricted_destinations" - type_name: "list" - params: - - type_name: "string" -- name: "spec.origin" - type_name: "string" -- name: "request" - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" -- name: "resource" - type_name: "map" - params: - - 
type_name: "string" - - type_name: "dyn" + - name: "destination.ip" + type_name: "string" + - name: "origin.ip" + type_name: "string" + - name: "spec.restricted_destinations" + type_name: "list" + params: + - type_name: "string" + - name: "spec.origin" + type_name: "string" + - name: "request" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" + - name: "resource" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" functions: -- name: "locationCode" - overloads: - - id: "locationCode_string" - args: - - type_name: "string" - return: - type_name: "string" + - name: "locationCode" + overloads: + - id: "locationCode_string" + args: + - type_name: "string" + return: + type_name: "string" validators: -- name: cel.validator.duration -- name: cel.validator.nesting_comprehension_limit - config: - limit: 2 + - name: cel.validator.duration + - name: cel.validator.nesting_comprehension_limit + config: + limit: 2 features: -- name: cel.feature.macro_call_tracking - enabled: true + - name: cel.feature.macro_call_tracking + enabled: true diff --git a/tools/compiler/testdata/custom_policy_config.yaml b/tools/compiler/testdata/custom_policy_config.yaml index 460fab4b4..7b54a43da 100644 --- a/tools/compiler/testdata/custom_policy_config.yaml +++ b/tools/compiler/testdata/custom_policy_config.yaml @@ -14,5 +14,5 @@ name: "custom_policy_config" extensions: -- name: "optional" - version: "latest" + - name: "optional" + version: "latest" From 840f741df07cddadba6fdf0a3a8e54a6e026f1e2 Mon Sep 17 00:00:00 2001 From: Chuang Wang Date: Tue, 1 Apr 2025 23:49:40 -0700 Subject: [PATCH 22/46] Update NativeTypes doc to reflect how to enable cel tag (#1148) --- ext/native.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/native.go b/ext/native.go index 1c33def49..661984cbb 100644 --- a/ext/native.go +++ b/ext/native.go @@ -81,7 +81,7 @@ var ( // the time that it is invoked. 
// // There is also the possibility to rename the fields of native structs by setting the `cel` tag -// for fields you want to override. In order to enable this feature, pass in the `EnableStructTag` +// for fields you want to override. In order to enable this feature, pass in the `ParseStructTags(true)` // option. Here is an example to see it in action: // // ```go From 8514549dac19b496d5ec755dd4dd7132176ade42 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 2 Apr 2025 17:23:05 -0700 Subject: [PATCH 23/46] Additional comments and coverage for Activation methods (#1150) --- interpreter/activation.go | 1 + interpreter/activation_test.go | 41 ++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/interpreter/activation.go b/interpreter/activation.go index 80105f4ff..dd40619ee 100644 --- a/interpreter/activation.go +++ b/interpreter/activation.go @@ -178,6 +178,7 @@ func (a *partActivation) AsPartialActivation() (PartialActivation, bool) { return a, true } +// AsPartialActivation walks the activation hierarchy and returns the first PartialActivation, if found. 
func AsPartialActivation(vars Activation) (PartialActivation, bool) { // Only internal activation instances may implement this interface if pv, ok := vars.(partialActivationConverter); ok { diff --git a/interpreter/activation_test.go b/interpreter/activation_test.go index fab4e4332..731313804 100644 --- a/interpreter/activation_test.go +++ b/interpreter/activation_test.go @@ -63,6 +63,25 @@ func TestActivation_ResolveLazy(t *testing.T) { } } +func TestActivation_ResolveLazyAny(t *testing.T) { + var v any + now := func() any { + if v == nil { + v = time.Now().Unix() + } + return v + } + a, _ := NewActivation(map[string]any{ + "now": now, + }) + first, _ := a.ResolveName("now") + second, _ := a.ResolveName("now") + if first != second { + t.Errorf("Got different second, "+ + "expected same as first: 1:%v 2:%v", first, second) + } +} + func TestHierarchicalActivation(t *testing.T) { // compose a parent with more properties than the child parent, _ := NewActivation(map[string]any{ @@ -89,3 +108,25 @@ func TestHierarchicalActivation(t *testing.T) { t.Error("Activation failed to resolve child value of 'c'") } } + +func TestAsPartialActivation(t *testing.T) { + // compose a parent with more properties than the child + parent, _ := NewPartialActivation(map[string]any{ + "a": types.String("world"), + "b": types.Int(-42), + }, NewAttributePattern("c")) + // compose the child such that it shadows the parent + child, _ := NewActivation(map[string]any{ + "d": types.String("universe"), + }) + combined := NewHierarchicalActivation(parent, child) + + // Resolve the shadowed child value. 
+ if part, found := AsPartialActivation(combined); found { + if part != parent { + t.Errorf("AsPartialActivation() got %v, wanted %v", part, parent) + } + } else { + t.Error("AsPartialActivation() failed, did not find parent partial activation") + } +} From 8de5d323bc0730e7dae9c84c38469892a2fea61c Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 7 Apr 2025 12:33:51 -0700 Subject: [PATCH 24/46] Update type formatting for type params (#1154) --- common/types/types.go | 3 +++ common/types/types_test.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/common/types/types.go b/common/types/types.go index 3ed514093..78c77a9b5 100644 --- a/common/types/types.go +++ b/common/types/types.go @@ -399,6 +399,9 @@ func (t *Type) WithTraits(traits int) *Type { // String returns a human-readable definition of the type name. func (t *Type) String() string { + if t.Kind() == TypeParamKind { + return fmt.Sprintf("<%s>", t.DeclaredTypeName()) + } if len(t.Parameters()) == 0 { return t.DeclaredTypeName() } diff --git a/common/types/types_test.go b/common/types/types_test.go index 14e9a8dee..c91d7148a 100644 --- a/common/types/types_test.go +++ b/common/types/types_test.go @@ -83,7 +83,7 @@ func TestTypeString(t *testing.T) { }, { in: NewTypeParamType("T"), - out: "T", + out: "", }, // nil-safety tests { From 2c79b146fb757064493607feaac6ad671e975807 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 8 Apr 2025 12:08:18 -0700 Subject: [PATCH 25/46] Lightweight observable evaluation (#1151) * Lightweight observable evaluation Transitions from replanning expressions on each invocation to generating stateful metadata in a concurrency friendly manner --- cel/options.go | 2 +- cel/program.go | 204 +++++++++-------------------- interpreter/attributes_test.go | 17 ++- interpreter/interpretable.go | 74 ++++++++--- interpreter/interpreter.go | 188 +++++++++++++++++++------- interpreter/interpreter_test.go | 23 ++-- interpreter/planner.go | 48 ++++--- 
interpreter/prune_test.go | 2 +- interpreter/runtimecost.go | 225 +++++++++++++++++++++++--------- interpreter/runtimecost_test.go | 5 +- 10 files changed, 470 insertions(+), 318 deletions(-) diff --git a/cel/options.go b/cel/options.go index 33a6c5e76..fee67323c 100644 --- a/cel/options.go +++ b/cel/options.go @@ -404,7 +404,7 @@ type ProgramOption func(p *prog) (*prog, error) // InterpretableDecorators can be used to inspect, alter, or replace the Program plan. func CustomDecorator(dec interpreter.InterpretableDecorator) ProgramOption { return func(p *prog) (*prog, error) { - p.decorators = append(p.decorators, dec) + p.plannerOptions = append(p.plannerOptions, interpreter.CustomDecorator(dec)) return p, nil } } diff --git a/cel/program.go b/cel/program.go index 144d1f25a..fe151eef1 100644 --- a/cel/program.go +++ b/cel/program.go @@ -151,30 +151,17 @@ type prog struct { // Intermediate state used to configure the InterpretableDecorator set provided // to the initInterpretable call. - decorators []interpreter.InterpretableDecorator + plannerOptions []interpreter.PlannerOption regexOptimizations []*interpreter.RegexOptimization // Interpretable configured from an Ast and aggregate decorator set based on program options. interpretable interpreter.Interpretable + observable *interpreter.ObservableInterpretable callCostEstimator interpreter.ActualCostEstimator costOptions []interpreter.CostTrackerOption costLimit *uint64 } -func (p *prog) clone() *prog { - costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions)) - copy(costOptsCopy, p.costOptions) - - return &prog{ - Env: p.Env, - evalOpts: p.evalOpts, - defaultVars: p.defaultVars, - dispatcher: p.dispatcher, - interpreter: p.interpreter, - interruptCheckFrequency: p.interruptCheckFrequency, - } -} - // newProgram creates a program instance with an environment, an ast, and an optional list of // ProgramOption values. 
// @@ -186,10 +173,10 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) { // Ensure the default attribute factory is set after the adapter and provider are // configured. p := &prog{ - Env: e, - decorators: []interpreter.InterpretableDecorator{}, - dispatcher: disp, - costOptions: []interpreter.CostTrackerOption{}, + Env: e, + plannerOptions: []interpreter.PlannerOption{}, + dispatcher: disp, + costOptions: []interpreter.CostTrackerOption{}, } // Configure the program via the ProgramOption values. @@ -227,74 +214,71 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) { p.interpreter = interp // Translate the EvalOption flags into InterpretableDecorator instances. - decorators := make([]interpreter.InterpretableDecorator, len(p.decorators)) - copy(decorators, p.decorators) + plannerOptions := make([]interpreter.PlannerOption, len(p.plannerOptions)) + copy(plannerOptions, p.plannerOptions) // Enable interrupt checking if there's a non-zero check frequency if p.interruptCheckFrequency > 0 { - decorators = append(decorators, interpreter.InterruptableEval()) + plannerOptions = append(plannerOptions, interpreter.InterruptableEval()) } // Enable constant folding first. if p.evalOpts&OptOptimize == OptOptimize { - decorators = append(decorators, interpreter.Optimize()) + plannerOptions = append(plannerOptions, interpreter.Optimize()) p.regexOptimizations = append(p.regexOptimizations, interpreter.MatchesRegexOptimization) } // Enable regex compilation of constants immediately after folding constants. if len(p.regexOptimizations) > 0 { - decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...)) + plannerOptions = append(plannerOptions, interpreter.CompileRegexConstants(p.regexOptimizations...)) } // Enable exhaustive eval, state tracking and cost tracking last since they require a factory. 
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 { - factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) { - costTracker.Estimator = p.callCostEstimator - costTracker.Limit = p.costLimit - for _, costOpt := range p.costOptions { - err := costOpt(costTracker) - if err != nil { - return nil, err - } - } - // Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This - // prevents the underlying memory from being shared between factory function calls causing - // undesired mutations. - decs := decorators[:len(decorators):len(decorators)] - var observers []interpreter.EvalObserver - - if p.evalOpts&(OptExhaustiveEval|OptTrackState) != 0 { - // EvalStateObserver is required for OptExhaustiveEval. - observers = append(observers, interpreter.EvalStateObserver(state)) - } - if p.evalOpts&OptTrackCost == OptTrackCost { - observers = append(observers, interpreter.CostObserver(costTracker)) - } - - // Enable exhaustive eval over a basic observer since it offers a superset of features. - if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval { - decs = append(decs, interpreter.ExhaustiveEval(), interpreter.Observe(observers...)) - } else if len(observers) > 0 { - decs = append(decs, interpreter.Observe(observers...)) - } - - return p.clone().initInterpretable(a, decs) + costOptCount := len(p.costOptions) + if p.costLimit != nil { + costOptCount++ + } + costOpts := make([]interpreter.CostTrackerOption, 0, costOptCount) + costOpts = append(costOpts, p.costOptions...) + if p.costLimit != nil { + costOpts = append(costOpts, interpreter.CostTrackerLimit(*p.costLimit)) + } + trackerFactory := func() (*interpreter.CostTracker, error) { + return interpreter.NewCostTracker(p.callCostEstimator, costOpts...) + } + var observers []interpreter.PlannerOption + if p.evalOpts&(OptExhaustiveEval|OptTrackState) != 0 { + // EvalStateObserver is required for OptExhaustiveEval. 
+ observers = append(observers, interpreter.EvalStateObserver()) + } + if p.evalOpts&OptTrackCost == OptTrackCost { + observers = append(observers, interpreter.CostObserver(interpreter.CostTrackerFactory(trackerFactory))) + } + // Enable exhaustive eval over a basic observer since it offers a superset of features. + if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval { + plannerOptions = append(plannerOptions, + append([]interpreter.PlannerOption{interpreter.ExhaustiveEval()}, observers...)...) + } else if len(observers) > 0 { + plannerOptions = append(plannerOptions, observers...) } - return newProgGen(factory) } - return p.initInterpretable(a, decorators) + return p.initInterpretable(a, plannerOptions) } -func (p *prog) initInterpretable(a *ast.AST, decs []interpreter.InterpretableDecorator) (*prog, error) { +func (p *prog) initInterpretable(a *ast.AST, plannerOptions []interpreter.PlannerOption) (*prog, error) { // When the AST has been exprAST it contains metadata that can be used to speed up program execution. - interpretable, err := p.interpreter.NewInterpretable(a, decs...) + interpretable, err := p.interpreter.NewInterpretable(a, plannerOptions...) if err != nil { return nil, err } p.interpretable = interpretable + if oi, ok := interpretable.(*interpreter.ObservableInterpretable); ok { + p.observable = oi + } return p, nil } // Eval implements the Program interface method. -func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) { +func (p *prog) Eval(input any) (out ref.Val, det *EvalDetails, err error) { // Configure error recovery for unexpected panics during evaluation. Note, the use of named // return values makes it possible to modify the error response during the recovery // function. 
@@ -322,12 +306,24 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) { if p.defaultVars != nil { vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars) } - v = p.interpretable.Eval(vars) + if p.observable != nil { + det = &EvalDetails{} + out = p.observable.ObserveEval(vars, func(observed any) { + switch o := observed.(type) { + case interpreter.EvalState: + det.state = o + case *interpreter.CostTracker: + det.costTracker = o + } + }) + } else { + out = p.interpretable.Eval(vars) + } // The output of an internal Eval may have a value (`v`) that is a types.Err. This step // translates the CEL value to a Go error response. This interface does not quite match the // RPC signature which allows for multiple errors to be returned, but should be sufficient. - if types.IsError(v) { - err = v.(*types.Err) + if types.IsError(out) { + err = out.(*types.Err) } return } @@ -355,88 +351,6 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail return p.Eval(vars) } -// progFactory is a helper alias for marking a program creation factory function. -type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error) - -// progGen holds a reference to a progFactory instance and implements the Program interface. -type progGen struct { - factory progFactory -} - -// newProgGen tests the factory object by calling it once and returns a factory-based Program if -// the test is successful. -func newProgGen(factory progFactory) (Program, error) { - // Test the factory to make sure that configuration errors are spotted at config - tracker, err := interpreter.NewCostTracker(nil) - if err != nil { - return nil, err - } - _, err = factory(interpreter.NewEvalState(), tracker) - if err != nil { - return nil, err - } - return &progGen{factory: factory}, nil -} - -// Eval implements the Program interface method. 
-func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) { - // The factory based Eval() differs from the standard evaluation model in that it generates a - // new EvalState instance for each call to ensure that unique evaluations yield unique stateful - // results. - state := interpreter.NewEvalState() - costTracker, err := interpreter.NewCostTracker(nil) - if err != nil { - return nil, nil, err - } - det := &EvalDetails{state: state, costTracker: costTracker} - - // Generate a new instance of the interpretable using the factory configured during the call to - // newProgram(). It is incredibly unlikely that the factory call will generate an error given - // the factory test performed within the Program() call. - p, err := gen.factory(state, costTracker) - if err != nil { - return nil, det, err - } - - // Evaluate the input, returning the result and the 'state' within EvalDetails. - v, _, err := p.Eval(input) - if err != nil { - return v, det, err - } - return v, det, nil -} - -// ContextEval implements the Program interface method. -func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) { - if ctx == nil { - return nil, nil, fmt.Errorf("context can not be nil") - } - // The factory based Eval() differs from the standard evaluation model in that it generates a - // new EvalState instance for each call to ensure that unique evaluations yield unique stateful - // results. - state := interpreter.NewEvalState() - costTracker, err := interpreter.NewCostTracker(nil) - if err != nil { - return nil, nil, err - } - det := &EvalDetails{state: state, costTracker: costTracker} - - // Generate a new instance of the interpretable using the factory configured during the call to - // newProgram(). It is incredibly unlikely that the factory call will generate an error given - // the factory test performed within the Program() call. 
- p, err := gen.factory(state, costTracker) - if err != nil { - return nil, det, err - } - - // Evaluate the input, returning the result and the 'state' within EvalDetails. - v, _, err := p.ContextEval(ctx, input) - if err != nil { - return v, det, err - } - return v, det, nil -} - type ctxEvalActivation struct { parent Activation interrupt <-chan struct{} diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index 86b6025f1..0e16f208f 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -1150,8 +1150,15 @@ func TestAttributeStateTracking(t *testing.T) { } interp := newStandardInterpreter(t, cont, reg, reg, attrs) // Show that program planning will now produce an error. - st := NewEvalState() - i, err := interp.NewInterpretable(checked, Optimize(), Observe(EvalStateObserver(st))) + type stateHolder struct { + st EvalState + } + holder := stateHolder{} + i, err := interp.NewInterpretable(checked, Optimize(), + EvalStateObserver(EvalStateFactory(func() EvalState { + holder.st = NewEvalState() + return holder.st + }))) if err != nil { t.Fatal(err) } @@ -1167,10 +1174,10 @@ func TestAttributeStateTracking(t *testing.T) { t.Errorf("got %v, wanted %v", out, tc.out) } for id, val := range tc.state { - stVal, found := st.Value(id) + stVal, found := holder.st.Value(id) if !found { - for _, id := range st.IDs() { - v, _ := st.Value(id) + for _, id := range holder.st.IDs() { + v, _ := holder.st.Value(id) t.Error(id, v) } t.Errorf("state not found for %d=%v", id, val) diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 1573523be..1990ce017 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -109,6 +109,44 @@ type InterpretableConstructor interface { Type() ref.Type } +// ObservableInterpretable is an Interpretable which supports stateful observation, such as tracing +// or cost-tracking. 
+type ObservableInterpretable struct { + Interpretable + observers []StatefulObserver +} + +// ID implements the Interpretable method to get the expression id associated with the step. +func (oi *ObservableInterpretable) ID() int64 { + return oi.Interpretable.ID() +} + +// Eval proxies to the ObserveEval method while invoking a no-op callback to report the observations. +func (oi *ObservableInterpretable) Eval(vars Activation) ref.Val { + return oi.ObserveEval(vars, func(any) {}) +} + +// ObserveEval evaluates an interpretable and performs per-evaluation state-tracking. +// +// This method is concurrency safe and the expectation is that the observer function will use +// a switch statement to determine the type of the state which has been reported back from the call. +func (oi *ObservableInterpretable) ObserveEval(vars Activation, observer func(any)) ref.Val { + var err error + // Initialize the state needed for the observers to function. + for _, obs := range oi.observers { + vars, err = obs.InitState(vars) + if err != nil { + return types.WrapErr(err) + } + } + result := oi.Interpretable.Eval(vars) + // Get the state which needs to be reported back as having been observed. + for _, obs := range oi.observers { + observer(obs.GetState(vars)) + } + return result +} + // Core Interpretable implementations used during the program planning phase. type evalTestOnly struct { @@ -822,9 +860,9 @@ type evalWatch struct { } // Eval implements the Interpretable interface method. -func (e *evalWatch) Eval(ctx Activation) ref.Val { - val := e.Interpretable.Eval(ctx) - e.observer(e.ID(), e.Interpretable, val) +func (e *evalWatch) Eval(vars Activation) ref.Val { + val := e.Interpretable.Eval(vars) + e.observer(vars, e.ID(), e.Interpretable, val) return val } @@ -883,7 +921,7 @@ func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) { // Eval implements the Interpretable interface method. 
func (e *evalWatchAttr) Eval(vars Activation) ref.Val { val := e.InterpretableAttribute.Eval(vars) - e.observer(e.ID(), e.InterpretableAttribute, val) + e.observer(vars, e.ID(), e.InterpretableAttribute, val) return val } @@ -904,7 +942,7 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) { } else { val = e.adapter.NativeToValue(out) } - e.observer(e.ID(), e.ConstantQualifier, val) + e.observer(vars, e.ID(), e.ConstantQualifier, val) return out, err } @@ -920,7 +958,7 @@ func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presence val = types.Bool(present) } if present || presenceOnly { - e.observer(e.ID(), e.ConstantQualifier, val) + e.observer(vars, e.ID(), e.ConstantQualifier, val) } return out, present, err } @@ -947,7 +985,7 @@ func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) { } else { val = e.adapter.NativeToValue(out) } - e.observer(e.ID(), e.Attribute, val) + e.observer(vars, e.ID(), e.Attribute, val) return out, err } @@ -963,7 +1001,7 @@ func (e *evalWatchAttrQual) QualifyIfPresent(vars Activation, obj any, presenceO val = types.Bool(present) } if present || presenceOnly { - e.observer(e.ID(), e.Attribute, val) + e.observer(vars, e.ID(), e.Attribute, val) } return out, present, err } @@ -984,7 +1022,7 @@ func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) { } else { val = e.adapter.NativeToValue(out) } - e.observer(e.ID(), e.Qualifier, val) + e.observer(vars, e.ID(), e.Qualifier, val) return out, err } @@ -1000,7 +1038,7 @@ func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly val = types.Bool(present) } if present || presenceOnly { - e.observer(e.ID(), e.Qualifier, val) + e.observer(vars, e.ID(), e.Qualifier, val) } return out, present, err } @@ -1014,7 +1052,7 @@ type evalWatchConst struct { // Eval implements the Interpretable interface method. 
func (e *evalWatchConst) Eval(vars Activation) ref.Val { val := e.Value() - e.observer(e.ID(), e.InterpretableConst, val) + e.observer(vars, e.ID(), e.InterpretableConst, val) return val } @@ -1187,13 +1225,13 @@ func (a *evalAttr) Eval(ctx Activation) ref.Val { } // Qualify proxies to the Attribute's Qualify method. -func (a *evalAttr) Qualify(ctx Activation, obj any) (any, error) { - return a.attr.Qualify(ctx, obj) +func (a *evalAttr) Qualify(vars Activation, obj any) (any, error) { + return a.attr.Qualify(vars, obj) } // QualifyIfPresent proxies to the Attribute's QualifyIfPresent method. -func (a *evalAttr) QualifyIfPresent(ctx Activation, obj any, presenceOnly bool) (any, bool, error) { - return a.attr.QualifyIfPresent(ctx, obj, presenceOnly) +func (a *evalAttr) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return a.attr.QualifyIfPresent(vars, obj, presenceOnly) } func (a *evalAttr) IsOptional() bool { @@ -1226,9 +1264,9 @@ func (c *evalWatchConstructor) ID() int64 { } // Eval implements the Interpretable Eval function. -func (c *evalWatchConstructor) Eval(ctx Activation) ref.Val { - val := c.constructor.Eval(ctx) - c.observer(c.ID(), c.constructor, val) +func (c *evalWatchConstructor) Eval(vars Activation) ref.Val { + val := c.constructor.Eval(vars) + c.observer(vars, c.ID(), c.constructor, val) return val } diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 0aca74d88..be57e7439 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -18,36 +18,41 @@ package interpreter import ( + "errors" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" ) +// PlannerOption configures the program plan options during interpretable setup. +type PlannerOption func(*planner) (*planner, error) + // Interpreter generates a new Interpretable from a checked or unchecked expression. 
type Interpreter interface {
	// NewInterpretable creates an Interpretable from a checked expression and an
	// optional list of PlannerOption values.
	NewInterpretable(exprAST *ast.AST, opts ...PlannerOption) (Interpretable, error)
}

// EvalObserver is a functional interface that accepts an expression id and an observed value.
// The id identifies the expression that was evaluated, the programStep is the Interpretable or Qualifier that
// was evaluated and value is the result of the evaluation.
//
// The vars Activation carries any per-evaluation state installed by a StatefulObserver,
// which is how the observer locates the state belonging to the current evaluation.
type EvalObserver func(vars Activation, id int64, programStep any, value ref.Val)

// StatefulObserver observes evaluation while tracking or utilizing stateful behavior.
type StatefulObserver interface {
	// InitState configures stateful metadata on the activation.
	InitState(Activation) (Activation, error)

	// GetState retrieves the stateful metadata from the activation.
	GetState(Activation) any

	// Observe passes the activation and relevant evaluation metadata to the observer.
	// The observe method is expected to do the equivalent of GetState(vars) in order
	// to find the metadata that needs to be updated upon invocation.
	Observe(vars Activation, id int64, programStep any, value ref.Val)
}
// evalStateOption configures the evalStateFactory behavior.
type evalStateOption func(*evalStateFactory) *evalStateFactory

// EvalStateFactory configures the EvalState generator to be used by the EvalStateObserver.
func EvalStateFactory(factory func() EvalState) evalStateOption {
	return func(fac *evalStateFactory) *evalStateFactory {
		fac.factory = factory
		return fac
	}
}

// EvalStateObserver provides an observer which records the value associated with the given expression id.
// EvalState must be provided to the observer.
//
// When no EvalStateFactory option is supplied, a fresh NewEvalState is produced per evaluation.
func EvalStateObserver(opts ...evalStateOption) PlannerOption {
	et := &evalStateFactory{factory: NewEvalState}
	for _, o := range opts {
		et = o(et)
	}
	return func(p *planner) (*planner, error) {
		if et.factory == nil {
			return nil, errors.New("eval state factory not configured")
		}
		// Register the factory both as a stateful observer (for per-eval state setup)
		// and as an eval decorator (so each program step reports into that state).
		p.observers = append(p.observers, et)
		p.decorators = append(p.decorators, decObserveEval(et.Observe))
		return p, nil
	}
}

// evalStateConverter identifies an object which is convertible to an EvalState instance.
type evalStateConverter interface {
	asEvalState() EvalState
}

// evalStateActivation hides state in the Activation in a manner not accessible to expressions.
type evalStateActivation struct {
	vars  Activation
	state EvalState
}

// ResolveName proxies variable lookups to the backing activation.
func (esa evalStateActivation) ResolveName(name string) (any, bool) {
	return esa.vars.ResolveName(name)
}

// Parent proxies parent lookups to the backing activation.
func (esa evalStateActivation) Parent() Activation {
	return esa.vars
}

// AsPartialActivation supports conversion to a partial activation in order to detect unknown attributes.
func (esa evalStateActivation) AsPartialActivation() (PartialActivation, bool) {
	return AsPartialActivation(esa.vars)
}

// asEvalState implements the evalStateConverter method.
func (esa evalStateActivation) asEvalState() EvalState {
	return esa.state
}

// asEvalState walks the Activation hierarchy and returns the first EvalState found, if present.
func asEvalState(vars Activation) (EvalState, bool) {
	if conv, ok := vars.(evalStateConverter); ok {
		return conv.asEvalState(), true
	}
	if vars.Parent() != nil {
		return asEvalState(vars.Parent())
	}
	return nil, false
}

// evalStateFactory holds a reference to a factory function that produces an EvalState instance.
type evalStateFactory struct {
	factory func() EvalState
}

// InitState produces an EvalState instance and bundles it into the Activation in a way which is
// not visible to expression evaluation.
func (et *evalStateFactory) InitState(vars Activation) (Activation, error) {
	state := et.factory()
	return evalStateActivation{vars: vars, state: state}, nil
}

// GetState extracts the EvalState from the Activation, or nil when none was installed.
func (et *evalStateFactory) GetState(vars Activation) any {
	if state, found := asEvalState(vars); found {
		return state
	}
	return nil
}

// Observe records the evaluation state for a given expression node and program step.
// The call is a no-op when the Activation carries no EvalState.
func (et *evalStateFactory) Observe(vars Activation, id int64, programStep any, val ref.Val) {
	state, found := asEvalState(vars)
	if !found {
		return
	}
	state.SetValue(id, val)
}

// CustomDecorator configures a custom interpretable decorator for the program.
func CustomDecorator(dec InterpretableDecorator) PlannerOption {
	return func(p *planner) (*planner, error) {
		p.decorators = append(p.decorators, dec)
		return p, nil
	}
}

// ExhaustiveEval disables short-circuiting behaviors during evaluation so that all
// branches are evaluated, providing insight into the evaluation state of the entire
// expression when combined with a state-tracking observer.
func ExhaustiveEval() PlannerOption {
	return CustomDecorator(decDisableShortcircuits())
}

// InterruptableEval annotates comprehension loops with information that indicates they
// should check for the presence of an interrupt flag during each iteration.
//
// The custom activation is currently managed higher up in the stack within the 'cel' package
// and should not require any custom support on behalf of callers.
func InterruptableEval() PlannerOption {
	return CustomDecorator(decInterruptFolds())
}
-func Optimize() InterpretableDecorator { - return decOptimize() +func Optimize() PlannerOption { + return CustomDecorator(decOptimize()) } // RegexOptimization provides a way to replace an InterpretableCall for a regex function when the @@ -142,8 +230,8 @@ type RegexOptimization struct { // CompileRegexConstants compiles regex pattern string constants at program creation time and reports any regex pattern // compile errors. -func CompileRegexConstants(regexOptimizations ...*RegexOptimization) InterpretableDecorator { - return decRegexOptimizer(regexOptimizations...) +func CompileRegexConstants(regexOptimizations ...*RegexOptimization) PlannerOption { + return CustomDecorator(decRegexOptimizer(regexOptimizations...)) } type exprInterpreter struct { @@ -172,14 +260,14 @@ func NewInterpreter(dispatcher Dispatcher, // NewIntepretable implements the Interpreter interface method. func (i *exprInterpreter) NewInterpretable( checked *ast.AST, - decorators ...InterpretableDecorator) (Interpretable, error) { - p := newPlanner( - i.dispatcher, - i.provider, - i.adapter, - i.attrFactory, - i.container, - checked, - decorators...) 
+ opts ...PlannerOption) (Interpretable, error) { + p := newPlanner(i.dispatcher, i.provider, i.adapter, i.attrFactory, i.container, checked) + var err error + for _, o := range opts { + p, err = o(p) + if err != nil { + return nil, err + } + } return p.Plan(checked.Expr()) } diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 498be1340..0f1057c42 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -59,7 +59,7 @@ type testCase struct { funcs []*decls.FunctionDecl attrs AttributeFactory unchecked bool - extraOpts []InterpretableDecorator + extraOpts []PlannerOption in any out any @@ -900,7 +900,7 @@ func testData(t testing.TB) []testCase { in: map[string]any{ "input": "kathmandu", }, - extraOpts: []InterpretableDecorator{CompileRegexConstants(MatchesRegexOptimization)}, + extraOpts: []PlannerOption{CompileRegexConstants(MatchesRegexOptimization)}, // unoptimized program should report a regex compile error at runtime err: "unexpected ): `)k.*`", // optimized program should report a regex compile at program creation time @@ -1584,10 +1584,11 @@ func TestInterpreter(t *testing.T) { } state := NewEvalState() - opts := map[string][]InterpretableDecorator{ - "optimize": {Optimize()}, - "exhaustive": {ExhaustiveEval(), Observe(EvalStateObserver(state))}, - "track": {Observe(EvalStateObserver(state))}, + opts := map[string][]PlannerOption{ + "optimize": {Optimize()}, + "exhaustive": {ExhaustiveEval(), + EvalStateObserver(EvalStateFactory(func() EvalState { return state }))}, + "track": {EvalStateObserver(EvalStateFactory(func() EvalState { return state }))}, } for mode, opt := range opts { opts := opt @@ -1698,8 +1699,8 @@ func TestInterpreter_ExhaustiveConditionalExpr(t *testing.T) { reg := newTestRegistry(t, &exprpb.ParsedExpr{}) attrs := NewAttributeFactory(cont, reg, reg) intr := newStandardInterpreter(t, cont, reg, reg, attrs) - interpretable, _ := intr.NewInterpretable(parsed, - ExhaustiveEval(), 
Observe(EvalStateObserver(state))) + interpretable, _ := intr.NewInterpretable(parsed, ExhaustiveEval(), + EvalStateObserver(EvalStateFactory(func() EvalState { return state }))) vars, _ := NewActivation(map[string]any{ "a": types.True, "b": types.Double(0.999), @@ -1781,8 +1782,8 @@ func TestInterpreter_ExhaustiveLogicalOrEquals(t *testing.T) { cont := testContainer("test") attrs := NewAttributeFactory(cont, reg, reg) interp := newStandardInterpreter(t, cont, reg, reg, attrs) - i, _ := interp.NewInterpretable(parsed, - ExhaustiveEval(), Observe(EvalStateObserver(state))) + i, _ := interp.NewInterpretable(parsed, ExhaustiveEval(), + EvalStateObserver(EvalStateFactory(func() EvalState { return state }))) vars, _ := NewActivation(map[string]any{ "a": true, "b": "b", @@ -2080,7 +2081,7 @@ func testContainer(name string) *containers.Container { return cont } -func program(t testing.TB, tst *testCase, opts ...InterpretableDecorator) (Interpretable, Activation, error) { +func program(t testing.TB, tst *testCase, opts ...PlannerOption) (Interpretable, Activation, error) { // Configure the package. cont := containers.DefaultContainer if tst.container != "" { diff --git a/interpreter/planner.go b/interpreter/planner.go index f0fd4eaf9..f0e0d4305 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -25,12 +25,6 @@ import ( "github.com/google/cel-go/common/types" ) -// interpretablePlanner creates an Interpretable evaluation plan from a proto Expr value. -type interpretablePlanner interface { - // Plan generates an Interpretable value (or error) from the input proto Expr. - Plan(expr ast.Expr) (Interpretable, error) -} - // newPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider, // TypeAdapter, Container, and CheckedExpr value. 
These pieces of data are used to resolve // functions, types, and namespaced identifiers at plan time rather than at runtime since @@ -40,8 +34,7 @@ func newPlanner(disp Dispatcher, adapter types.Adapter, attrFactory AttributeFactory, cont *containers.Container, - exprAST *ast.AST, - decorators ...InterpretableDecorator) interpretablePlanner { + exprAST *ast.AST) *planner { return &planner{ disp: disp, provider: provider, @@ -50,7 +43,8 @@ func newPlanner(disp Dispatcher, container: cont, refMap: exprAST.ReferenceMap(), typeMap: exprAST.TypeMap(), - decorators: decorators, + decorators: make([]InterpretableDecorator, 0), + observers: make([]StatefulObserver, 0), } } @@ -64,6 +58,7 @@ type planner struct { refMap map[int64]*ast.ReferenceInfo typeMap map[int64]*types.Type decorators []InterpretableDecorator + observers []StatefulObserver } // Plan implements the interpretablePlanner interface. This implementation of the Plan method also @@ -72,6 +67,17 @@ type planner struct { // such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of // repeated expressions. func (p *planner) Plan(expr ast.Expr) (Interpretable, error) { + i, err := p.plan(expr) + if err != nil { + return nil, err + } + if len(p.observers) == 0 { + return i, nil + } + return &ObservableInterpretable{Interpretable: i, observers: p.observers}, nil +} + +func (p *planner) plan(expr ast.Expr) (Interpretable, error) { switch expr.Kind() { case ast.CallKind: return p.decorate(p.planCall(expr)) @@ -161,7 +167,7 @@ func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) { sel := expr.AsSelect() // Plan the operand evaluation. 
- op, err := p.Plan(sel.Operand()) + op, err := p.plan(sel.Operand()) if err != nil { return nil, err } @@ -220,14 +226,14 @@ func (p *planner) planCall(expr ast.Expr) (Interpretable, error) { args := make([]Interpretable, argCount) if target != nil { - arg, err := p.Plan(target) + arg, err := p.plan(target) if err != nil { return nil, err } args[0] = arg } for i, argExpr := range call.Args() { - arg, err := p.Plan(argExpr) + arg, err := p.plan(argExpr) if err != nil { return nil, err } @@ -496,7 +502,7 @@ func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) { } elems := make([]Interpretable, len(elements)) for i, elem := range elements { - elemVal, err := p.Plan(elem) + elemVal, err := p.plan(elem) if err != nil { return nil, err } @@ -521,13 +527,13 @@ func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) { hasOptionals := false for i, e := range entries { entry := e.AsMapEntry() - keyVal, err := p.Plan(entry.Key()) + keyVal, err := p.plan(entry.Key()) if err != nil { return nil, err } keys[i] = keyVal - valVal, err := p.Plan(entry.Value()) + valVal, err := p.plan(entry.Value()) if err != nil { return nil, err } @@ -560,7 +566,7 @@ func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) { for i, f := range objFields { field := f.AsStructField() fields[i] = field.Name() - val, err := p.Plan(field.Value()) + val, err := p.plan(field.Value()) if err != nil { return nil, err } @@ -582,23 +588,23 @@ func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) { // planComprehension generates an Interpretable fold operation. 
func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { fold := expr.AsComprehension() - accu, err := p.Plan(fold.AccuInit()) + accu, err := p.plan(fold.AccuInit()) if err != nil { return nil, err } - iterRange, err := p.Plan(fold.IterRange()) + iterRange, err := p.plan(fold.IterRange()) if err != nil { return nil, err } - cond, err := p.Plan(fold.LoopCondition()) + cond, err := p.plan(fold.LoopCondition()) if err != nil { return nil, err } - step, err := p.Plan(fold.LoopStep()) + step, err := p.plan(fold.LoopStep()) if err != nil { return nil, err } - result, err := p.Plan(fold.Result()) + result, err := p.plan(fold.Result()) if err != nil { return nil, err } diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index 7a94c5270..e333c1ef5 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -498,7 +498,7 @@ func TestPrune(t *testing.T) { dispatcher.Add(funcBindings(t, optionalDecls(t)...)...) interp := NewInterpreter(dispatcher, containers.DefaultContainer, reg, reg, attrs) interpretable, err := interp.NewInterpretable(parsed, - ExhaustiveEval(), Observe(EvalStateObserver(state))) + ExhaustiveEval(), EvalStateObserver(EvalStateFactory(func() EvalState { return state }))) if err != nil { t.Fatalf("NewUncheckedInterpretable() failed: %v", err) } diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index 8f47c53d2..6c44cd798 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -15,6 +15,7 @@ package interpreter import ( + "errors" "math" "github.com/google/cel-go/common" @@ -34,78 +35,172 @@ type ActualCostEstimator interface { CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 } +// costTrackPlanOption modifies the cost tracking factory associatied with the CostObserver +type costTrackPlanOption func(*costTrackerFactory) *costTrackerFactory + +// CostTrackerFactory configures the factory method to generate a new cost-tracker per-evaluation. 
func CostTrackerFactory(factory func() (*CostTracker, error)) costTrackPlanOption {
	return func(fac *costTrackerFactory) *costTrackerFactory {
		fac.factory = factory
		return fac
	}
}

// CostObserver provides an observer that tracks runtime cost.
//
// Unlike EvalStateObserver, no default factory exists: CostTracker construction may fail
// (its factory returns an error), so a CostTrackerFactory option is mandatory and its
// absence is reported as a planning error.
func CostObserver(opts ...costTrackPlanOption) PlannerOption {
	ct := &costTrackerFactory{}
	for _, o := range opts {
		ct = o(ct)
	}
	return func(p *planner) (*planner, error) {
		if ct.factory == nil {
			return nil, errors.New("cost tracker factory not configured")
		}
		// Register as both a stateful observer (per-eval tracker setup) and an eval
		// decorator (per-step cost accounting into that tracker).
		p.observers = append(p.observers, ct)
		p.decorators = append(p.decorators, decObserveEval(ct.Observe))
		return p, nil
	}
}
// costTrackerConverter identifies an object which is convertible to a CostTracker instance.
type costTrackerConverter interface {
	asCostTracker() *CostTracker
}

// costTrackActivation hides state in the Activation in a manner not accessible to expressions.
type costTrackActivation struct {
	vars        Activation
	costTracker *CostTracker
}

// ResolveName proxies variable lookups to the backing activation.
func (cta costTrackActivation) ResolveName(name string) (any, bool) {
	return cta.vars.ResolveName(name)
}

// Parent proxies parent lookups to the backing activation.
func (cta costTrackActivation) Parent() Activation {
	return cta.vars
}

// AsPartialActivation supports conversion to a partial activation in order to detect unknown attributes.
func (cta costTrackActivation) AsPartialActivation() (PartialActivation, bool) {
	return AsPartialActivation(cta.vars)
}

// asCostTracker implements the costTrackerConverter method.
func (cta costTrackActivation) asCostTracker() *CostTracker {
	return cta.costTracker
}

// asCostTracker walks the Activation hierarchy and returns the first cost tracker found, if present.
func asCostTracker(vars Activation) (*CostTracker, bool) {
	if conv, ok := vars.(costTrackerConverter); ok {
		return conv.asCostTracker(), true
	}
	if vars.Parent() != nil {
		return asCostTracker(vars.Parent())
	}
	return nil, false
}

// costTrackerFactory holds a factory for producing new CostTracker instances on each Eval call.
type costTrackerFactory struct {
	factory func() (*CostTracker, error)
}

// InitState produces a CostTracker and bundles it into an Activation in a way which is not visible
// to expression evaluation.
func (ct *costTrackerFactory) InitState(vars Activation) (Activation, error) {
	tracker, err := ct.factory()
	if err != nil {
		return nil, err
	}
	return costTrackActivation{vars: vars, costTracker: tracker}, nil
}

// GetState extracts the CostTracker from the Activation, or nil when none was installed.
func (ct *costTrackerFactory) GetState(vars Activation) any {
	if tracker, found := asCostTracker(vars); found {
		return tracker
	}
	return nil
}

// Observe computes the incremental cost of each step and records it into the CostTracker associated
// with the evaluation.
func (ct *costTrackerFactory) Observe(vars Activation, id int64, programStep any, val ref.Val) {
	tracker, found := asCostTracker(vars)
	if !found {
		// No tracker was installed for this evaluation; nothing to record.
		return
	}
	switch t := programStep.(type) {
	case ConstantQualifier:
		// TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them
		// and enable the below pop. Once enabled this can case can be collapsed into the Qualifier case.
		tracker.cost++
	case InterpretableConst:
		// zero cost
	case InterpretableAttribute:
		switch a := t.Attr().(type) {
		case *conditionalAttribute:
			// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
			tracker.stack.drop(a.falsy.ID(), a.truthy.ID(), a.expr.ID())
		default:
			tracker.stack.drop(t.Attr().ID())
			tracker.cost += common.SelectAndIdentCost
		}
		if !tracker.presenceTestHasCost {
			if _, isTestOnly := programStep.(*evalTestOnly); isTestOnly {
				tracker.cost -= common.SelectAndIdentCost
			}
		}
	case *evalExhaustiveConditional:
		// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
		tracker.stack.drop(t.attr.falsy.ID(), t.attr.truthy.ID(), t.attr.expr.ID())

	// While the field names are identical, the boolean operation eval structs do not share an interface and so
	// must be handled individually.
	case *evalOr:
		for _, term := range t.terms {
			tracker.stack.drop(term.ID())
		}
	case *evalAnd:
		for _, term := range t.terms {
			tracker.stack.drop(term.ID())
		}
	case *evalExhaustiveOr:
		for _, term := range t.terms {
			tracker.stack.drop(term.ID())
		}
	case *evalExhaustiveAnd:
		for _, term := range t.terms {
			tracker.stack.drop(term.ID())
		}
	case *evalFold:
		tracker.stack.drop(t.iterRange.ID())
	case Qualifier:
		tracker.cost++
	case InterpretableCall:
		if argVals, ok := tracker.stack.dropArgs(t.Args()); ok {
			tracker.cost += tracker.costCall(t, argVals, val)
		}
	case InterpretableConstructor:
		tracker.stack.dropArgs(t.InitVals())
		switch t.Type() {
		case types.ListType:
			tracker.cost += common.ListCreateBaseCost
		case types.MapType:
			tracker.cost += common.MapCreateBaseCost
		default:
			tracker.cost += common.StructCreateBaseCost
		}
	}
	tracker.stack.push(val, id)

	// Halt evaluation immediately once the configured cost limit has been exceeded.
	if tracker.Limit != nil && tracker.cost > *tracker.Limit {
		panic(EvalCancelledError{Cause: CostLimitExceeded, Message: "operation cancelled: actual cost limit exceeded"})
	}
}
CostTrackerOption configures the behavior of CostTracker objects. diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index 6e6bc6b54..54e10f142 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -143,7 +143,10 @@ func computeCost(t *testing.T, expr string, vars []*decls.VariableDecl, ctx Acti t.Fatalf("checker.Cost() failed: %v", err) } interp := newStandardInterpreter(t, cont, reg, reg, attrs) - prg, err := interp.NewInterpretable(checked, Observe(CostObserver(costTracker))) + prg, err := interp.NewInterpretable(checked, + CostObserver(CostTrackerFactory(func() (*CostTracker, error) { + return costTracker, nil + }))) if err != nil { t.Fatalf(`Failed to check expression "%s", error: %v`, expr, errs.GetErrors()) } From a6fbac99134e665e6eb13b21ee4f482300ff94d0 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 8 Apr 2025 12:37:56 -0700 Subject: [PATCH 26/46] Utilities for formatting and parsing documentation strings (#1155) * Utilities for formatting and parsing documentation strings * Additional test case --- common/BUILD.bazel | 2 + common/doc.go | 194 +++++++++++++++++++++++++++++++++++++++++ common/doc_test.go | 210 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 406 insertions(+) create mode 100644 common/doc_test.go diff --git a/common/BUILD.bazel b/common/BUILD.bazel index eef7f281b..1b1b7914d 100644 --- a/common/BUILD.bazel +++ b/common/BUILD.bazel @@ -9,6 +9,7 @@ go_library( name = "go_default_library", srcs = [ "cost.go", + "doc.go", "error.go", "errors.go", "location.go", @@ -25,6 +26,7 @@ go_test( name = "go_default_test", size = "small", srcs = [ + "doc_test.go", "errors_test.go", "source_test.go", ], diff --git a/common/doc.go b/common/doc.go index 5362fdfe4..c68df818f 100644 --- a/common/doc.go +++ b/common/doc.go @@ -15,3 +15,197 @@ // Package common defines types and utilities common to expression parsing, // checking, and interpretation package common + +import ( + 
"fmt" + "strings" + "unicode" +) + +// DocKind indicates the type of documentation element. +type DocKind int + +const ( + // DocEnv represents environment variable documentation. + DocEnv DocKind = iota + 1 + // DocFunction represents function documentation. + DocFunction + // DocOverload represents function overload documentation. + DocOverload + // DocVariable represents variable documentation. + DocVariable + // DocMacro represents macro documentation. + DocMacro + // DocExample represents example documentation. + DocExample +) + +// MultilineDescription represents a description that can span multiple lines, +// stored as a slice of strings. +type MultilineDescription []string + +// Doc holds the documentation details for a specific program element like +// a variable, function, macro, or example. +type Doc struct { + // Kind specifies the type of documentation element (e.g., Function, Variable). + Kind DocKind + + // Name is the identifier of the documented element (e.g., function name, variable name). + Name string + + // Type is the data type associated with the element, primarily used for variables. + Type string + + // Signature represents the function or overload signature. + Signature string + + // Description holds the textual description of the element, potentially spanning multiple lines. + Description MultilineDescription + + // Children holds nested documentation elements, such as overloads for a function + // or examples for a function/macro. + Children []*Doc +} + +// FormatDescription joins multiple description elements (string, MultilineDescription, +// or []MultilineDescription) into a single string, separated by double newlines ("\n\n"). +// It returns the formatted string or an error if an unsupported type is encountered. +func FormatDescription(descs ...any) (string, error) { + return FormatDescriptionSeparator("\n\n", descs...) 
+} + +// FormatDescriptionSeparator joins multiple description elements (string, MultilineDescription, +// or []MultilineDescription) into a single string using the specified separator. +// It returns the formatted string or an error if an unsupported description type is passed. +func FormatDescriptionSeparator(sep string, descs ...any) (string, error) { + var builder strings.Builder + hasDoc := false + for _, d := range descs { + if hasDoc { + builder.WriteString(sep) + } + switch v := d.(type) { + case string: + builder.WriteString(v) + case MultilineDescription: + str := strings.Join(v, "\n") + builder.WriteString(str) + case []MultilineDescription: + for _, md := range v { + if hasDoc { + builder.WriteString(sep) + } + str := strings.Join(md, "\n") + builder.WriteString(str) + hasDoc = true + } + default: + return "", fmt.Errorf("unsupported description type: %T", d) + } + hasDoc = true + } + return builder.String(), nil +} + +// ParseDescription takes a single string containing newline characters and splits +// it into a MultilineDescription. All empty lines will be skipped. +// +// Returns an empty MultilineDescription if the input string is empty. +func ParseDescription(doc string) MultilineDescription { + var lines MultilineDescription + if len(doc) != 0 { + // Split the input string by newline characters. + for _, line := range strings.Split(doc, "\n") { + l := strings.TrimRightFunc(line, unicode.IsSpace) + if len(l) == 0 { + continue + } + lines = append(lines, l) + } + } + // Return an empty slice if the input is empty. + return lines +} + +// ParseDescriptions splits a documentation string into multiple MultilineDescription +// sections, using blank lines as delimiters. +func ParseDescriptions(doc string) []MultilineDescription { + var examples []MultilineDescription + if len(doc) != 0 { + lines := strings.Split(doc, "\n") + lineStart := 0 + for i, l := range lines { + // Trim trailing whitespace to identify effectively blank lines. 
+ l = strings.TrimRightFunc(l, unicode.IsSpace)
+ // If a line is blank, it marks the end of the current section.
+ if len(l) == 0 {
+ // Start the next section after the blank line.
+ ex := lines[lineStart:i]
+ if len(ex) != 0 {
+ examples = append(examples, ex)
+ }
+ lineStart = i + 1
+ }
+ }
+ // Append the last section if it wasn't terminated by a blank line.
+ if lineStart < len(lines) {
+ examples = append(examples, lines[lineStart:])
+ }
+ }
+ return examples
+}
+
+// NewVariableDoc creates a new Doc struct specifically for documenting a variable.
+func NewVariableDoc(name, celType, description string) *Doc {
+ return &Doc{
+ Kind: DocVariable,
+ Name: name,
+ Type: celType,
+ Description: ParseDescription(description),
+ }
+}
+
+// NewFunctionDoc creates a new Doc struct for documenting a function.
+func NewFunctionDoc(name, description string, overloads ...*Doc) *Doc {
+ return &Doc{
+ Kind: DocFunction,
+ Name: name,
+ Description: ParseDescription(description),
+ Children: overloads,
+ }
+}
+
+// NewOverloadDoc creates a new Doc struct for a function overload.
+func NewOverloadDoc(id, signature string, examples ...*Doc) *Doc {
+ return &Doc{
+ Kind: DocOverload,
+ Name: id,
+ Signature: signature,
+ Children: examples,
+ }
+}
+
+// NewMacroDoc creates a new Doc struct for documenting a macro.
+func NewMacroDoc(name, description string, examples ...*Doc) *Doc {
+ return &Doc{
+ Kind: DocMacro,
+ Name: name,
+ Description: ParseDescription(description),
+ Children: examples,
+ }
+}
+
+// NewExampleDoc creates a new Doc struct specifically for holding an example.
+func NewExampleDoc(ex MultilineDescription) *Doc {
+ return &Doc{
+ Kind: DocExample,
+ Description: ex,
+ }
+}
+
+// Documentor is an interface for types that can provide their own documentation.
+type Documentor interface {
+ // Documentation returns the documentation coded by the DocKind to assist
+ // with text formatting.
+ Documentation() *Doc +} diff --git a/common/doc_test.go b/common/doc_test.go new file mode 100644 index 000000000..2f06b3d8c --- /dev/null +++ b/common/doc_test.go @@ -0,0 +1,210 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "errors" + "reflect" + "strings" + "testing" +) + +func TestFormatDescription(t *testing.T) { + tests := []struct { + name string + in []any + out string + err error + }{ + { + name: "two separate examples as strings", + in: []any{"hello", "world"}, + out: "hello\n\nworld", + }, + { + name: "single example as multiline string", + in: []any{MultilineDescription{"hello", "world"}}, + out: "hello\nworld", + }, + { + name: "two examples as a list of multiline strings", + in: []any{[]MultilineDescription{{"hello", "world"}, {"goodbye", "cruel world"}}}, + out: "hello\nworld\n\ngoodbye\ncruel world", + }, + { + name: "invalid description", + in: []any{1}, + err: errors.New("unsupported description type"), + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + out, err := FormatDescription(tc.in...) 
+ if err != nil { + if tc.err == nil || !strings.Contains(err.Error(), tc.err.Error()) { + t.Fatalf("FormatDescription() errored with %v, wanted %v", err, tc) + } + return + } + if out != tc.out { + t.Errorf("FormatDescription() got %s, wanted %v", out, tc) + } + }) + } +} + +func TestParseDescription(t *testing.T) { + tests := []struct { + name string + in string + out MultilineDescription + }{ + { + name: "empty", + }, + { + name: "single", + in: "hello", + out: MultilineDescription{"hello"}, + }, + { + name: "multi", + in: "hello\n\n\nworld", + out: MultilineDescription{"hello", "world"}, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + out := ParseDescription(tc.in) + if !reflect.DeepEqual(out, tc.out) { + t.Errorf("ParseDescription() got %v, wanted %v", out, tc.out) + } + }) + } +} + +func TestParseDescriptions(t *testing.T) { + tests := []struct { + name string + in string + out []MultilineDescription + }{ + { + name: "empty", + }, + { + name: "single", + in: "hello", + out: []MultilineDescription{{"hello"}}, + }, + { + name: "multi", + in: "bar\nbaz\n\nfoo", + out: []MultilineDescription{{"bar", "baz"}, {"foo"}}, + }, + { + name: "multi", + in: "hello\n\n\nworld", + out: []MultilineDescription{{"hello"}, {"world"}}, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + out := ParseDescriptions(tc.in) + if !reflect.DeepEqual(out, tc.out) { + t.Errorf("ParseDescriptions() got %v, wanted %v", out, tc.out) + } + }) + } +} + +func TestNewDoc(t *testing.T) { + tests := []struct { + newDoc func() *Doc + kind DocKind + name string + celType string + sig string + desc MultilineDescription + childCount int + }{ + { + newDoc: func() *Doc { + return NewMacroDoc("map", "map converts a list or map of values to a list", + NewExampleDoc(MultilineDescription{"[1, 2].map(i, i * 2) // [2, 4]"})) + }, + kind: DocMacro, + name: "map", + desc: MultilineDescription{"map converts a list or map of 
values to a list"}, + childCount: 1, + }, + { + newDoc: func() *Doc { + return NewVariableDoc( + "request", + "google.rpc.context.AttributeContext.Request", + "parameters related to an HTTP API request") + }, + kind: DocVariable, + name: "request", + celType: "google.rpc.context.AttributeContext.Request", + desc: MultilineDescription{"parameters related to an HTTP API request"}, + childCount: 0, + }, + { + newDoc: func() *Doc { + return NewFunctionDoc("getToken", + "get the JWT token from a request\nas deserialized JSON", + NewOverloadDoc("request_getToken", "request.getToken() -> map(string, dyn)", + NewExampleDoc(MultilineDescription{"has(request.getToken().sub) // false"}))) + }, + kind: DocFunction, + name: "getToken", + desc: MultilineDescription{"get the JWT token from a request", "as deserialized JSON"}, + childCount: 1, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + d := tc.newDoc() + if d.Kind != tc.kind { + t.Errorf("got doc kind %v, wanted %v", d.Kind, tc.kind) + } + if d.Name != tc.name { + t.Errorf("got doc name %s, wanted %s", d.Name, tc.name) + } + if d.Signature != tc.sig { + t.Errorf("got signature %s, wanted %s", d.Signature, tc.sig) + } + if !reflect.DeepEqual(d.Description, tc.desc) { + t.Errorf("got description %v, wanted %v", d.Description, tc.desc) + } + if d.Type != tc.celType { + t.Errorf("got type %s, wanted %s", d.Type, tc.celType) + } + if len(d.Children) != tc.childCount { + t.Errorf("got children %v, wanted count %d", d.Children, tc.childCount) + } + }) + } +} From 997fbb2682033760563d318ad9aa6409c17f30fe Mon Sep 17 00:00:00 2001 From: Yuchen Shi Date: Wed, 9 Apr 2025 18:21:38 -0700 Subject: [PATCH 27/46] Re-export interpreter.AttributePattern in package cel. (#1158) The cel.AttributePattern name is taken by a helper function. A different name is thus picked to avoid breaking existing consumers of package cel. 
--- cel/program.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/cel/program.go b/cel/program.go index fe151eef1..24f41a4a7 100644 --- a/cel/program.go +++ b/cel/program.go @@ -75,7 +75,7 @@ func NewActivation(bindings any) (Activation, error) { return interpreter.NewActivation(bindings) } -// PartialActivation extends the Activation interface with a set of UnknownAttributePatterns. +// PartialActivation extends the Activation interface with a set of unknown AttributePatterns. type PartialActivation = interpreter.PartialActivation // NoVars returns an empty Activation. @@ -91,7 +91,7 @@ func NoVars() Activation { // // The `vars` value may either be an Activation or any valid input to the NewActivation call. func PartialVars(vars any, - unknowns ...*interpreter.AttributePattern) (PartialActivation, error) { + unknowns ...*AttributePatternType) (PartialActivation, error) { return interpreter.NewPartialActivation(vars, unknowns...) } @@ -108,12 +108,15 @@ func PartialVars(vars any, // fully qualified variable name may be `ns.app.a`, `ns.a`, or `a` per the CEL namespace resolution // rules. Pick the fully qualified variable name that makes sense within the container as the // AttributePattern `varName` argument. +func AttributePattern(varName string) *AttributePatternType { + return interpreter.NewAttributePattern(varName) +} + +// AttributePatternType represents a top-level variable with an optional set of qualifier patterns. // // See the interpreter.AttributePattern and interpreter.AttributeQualifierPattern for more info // about how to create and manipulate AttributePattern values. -func AttributePattern(varName string) *interpreter.AttributePattern { - return interpreter.NewAttributePattern(varName) -} +type AttributePatternType = interpreter.AttributePattern // EvalDetails holds additional information observed during the Eval() call. 
type EvalDetails struct { From bdf49d600fb8949da4f2bf1b99d66d2b5a554e3b Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 10 Apr 2025 10:56:50 -0700 Subject: [PATCH 28/46] Support for documentation and example strings in CEL environments (#1156) * Support for documentation and example strings in CEL environments --- checker/decls/decls.go | 33 +++- common/decls/BUILD.bazel | 2 + common/decls/decls.go | 171 +++++++++++++++++++- common/decls/decls_test.go | 220 +++++++++++++++++++++++++- common/doc.go | 70 ++------ common/doc_test.go | 74 ++------- common/env/BUILD.bazel | 1 + common/env/env.go | 65 +++++--- common/env/env_test.go | 75 +++++++-- common/env/testdata/context_env.yaml | 8 + common/env/testdata/extended_env.yaml | 11 ++ common/env/testdata/subset_env.yaml | 1 + 12 files changed, 567 insertions(+), 164 deletions(-) diff --git a/checker/decls/decls.go b/checker/decls/decls.go index ef1c4bbb4..e013d2c2b 100644 --- a/checker/decls/decls.go +++ b/checker/decls/decls.go @@ -91,6 +91,17 @@ func NewFunction(name string, Overloads: overloads}}} } +// NewFunctionWithDoc creates a named function declaration with a description and one or more overloads. +func NewFunctionWithDoc(name, doc string, + overloads ...*exprpb.Decl_FunctionDecl_Overload) *exprpb.Decl { + return &exprpb.Decl{ + Name: name, + DeclKind: &exprpb.Decl_Function{ + Function: &exprpb.Decl_FunctionDecl{ + // Doc: desc, + Overloads: overloads}}} +} + // NewIdent creates a named identifier declaration with an optional literal // value. // @@ -98,28 +109,37 @@ func NewFunction(name string, // // Deprecated: Use NewVar or NewConst instead. 
func NewIdent(name string, t *exprpb.Type, v *exprpb.Constant) *exprpb.Decl { + return newIdent(name, t, v, "") +} + +func newIdent(name string, t *exprpb.Type, v *exprpb.Constant, desc string) *exprpb.Decl { return &exprpb.Decl{ Name: name, DeclKind: &exprpb.Decl_Ident{ Ident: &exprpb.Decl_IdentDecl{ Type: t, - Value: v}}} + Value: v, + Doc: desc}}} } // NewConst creates a constant identifier with a CEL constant literal value. func NewConst(name string, t *exprpb.Type, v *exprpb.Constant) *exprpb.Decl { - return NewIdent(name, t, v) + return newIdent(name, t, v, "") } // NewVar creates a variable identifier. func NewVar(name string, t *exprpb.Type) *exprpb.Decl { - return NewIdent(name, t, nil) + return newIdent(name, t, nil, "") +} + +// NewVarWithDoc creates a variable identifier with a type and a description string. +func NewVarWithDoc(name string, t *exprpb.Type, desc string) *exprpb.Decl { + return newIdent(name, t, nil, desc) } // NewInstanceOverload creates a instance function overload contract. // First element of argTypes is instance. -func NewInstanceOverload(id string, argTypes []*exprpb.Type, - resultType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload { +func NewInstanceOverload(id string, argTypes []*exprpb.Type, resultType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload { return &exprpb.Decl_FunctionDecl_Overload{ OverloadId: id, ResultType: resultType, @@ -154,8 +174,7 @@ func NewObjectType(typeName string) *exprpb.Type { // NewOverload creates a function overload declaration which contains a unique // overload id as well as the expected argument and result types. Overloads // must be aggregated within a Function declaration. 
-func NewOverload(id string, argTypes []*exprpb.Type, - resultType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload { +func NewOverload(id string, argTypes []*exprpb.Type, resultType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload { return &exprpb.Decl_FunctionDecl_Overload{ OverloadId: id, ResultType: resultType, diff --git a/common/decls/BUILD.bazel b/common/decls/BUILD.bazel index 17791dce6..bd3f9ae70 100644 --- a/common/decls/BUILD.bazel +++ b/common/decls/BUILD.bazel @@ -13,7 +13,9 @@ go_library( importpath = "github.com/google/cel-go/common/decls", deps = [ "//checker/decls:go_default_library", + "//common:go_default_library", "//common/functions:go_default_library", + "//common/operators:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "//common/types/traits:go_default_library", diff --git a/common/decls/decls.go b/common/decls/decls.go index cec22707a..a0fa6bcbd 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -20,7 +20,9 @@ import ( "strings" chkdecls "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/functions" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -54,6 +56,7 @@ func NewFunction(name string, opts ...FunctionOpt) (*FunctionDecl, error) { // overload instances. type FunctionDecl struct { name string + doc string // overloads associated with the function name. overloads map[string]*OverloadDecl @@ -84,6 +87,26 @@ const ( declarationEnabled ) +// Documentation generates documentation about the Function and its overloads as a common.Doc object. 
+func (f *FunctionDecl) Documentation() *common.Doc { + if f == nil { + return nil + } + children := make([]*common.Doc, len(f.OverloadDecls())) + for i, o := range f.OverloadDecls() { + var examples []*common.Doc + for _, ex := range o.Examples() { + examples = append(examples, common.NewExampleDoc(ex)) + } + od := common.NewOverloadDoc(o.ID(), formatSignature(f.Name(), o), examples...) + children[i] = od + } + return common.NewFunctionDoc( + f.Name(), + f.Description(), + children...) +} + // Name returns the function name in human-readable terms, e.g. 'contains' of 'math.least' func (f *FunctionDecl) Name() string { if f == nil { @@ -92,9 +115,22 @@ func (f *FunctionDecl) Name() string { return f.name } +// Description provides an overview of the function's purpose. +// +// Usage examples should be included on specific overloads. +func (f *FunctionDecl) Description() string { + if f == nil { + return "" + } + return f.doc +} + // IsDeclarationDisabled indicates that the function implementation should be added to the dispatcher, but the // declaration should not be exposed for use in expressions. func (f *FunctionDecl) IsDeclarationDisabled() bool { + if f == nil { + return true + } return f.state == declarationDisabled } @@ -107,8 +143,8 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) { if f == other { return f, nil } - if f.Name() != other.Name() { - return nil, fmt.Errorf("cannot merge unrelated functions. %s and %s", f.Name(), other.Name()) + if f == nil || other == nil || f.Name() != other.Name() { + return nil, fmt.Errorf("cannot merge unrelated functions. %q and %q", f.Name(), other.Name()) } merged := &FunctionDecl{ name: f.Name(), @@ -120,12 +156,17 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) { disableTypeGuards: f.disableTypeGuards && other.disableTypeGuards, // default to the current functions declaration state. 
 state: f.state,
+ doc: f.doc,
 }
 // If the other state indicates that the declaration should be explicitly enabled or
 // disabled, then update the merged state with the most recent value.
 if other.state != declarationStateUnset {
 merged.state = other.state
 }
+ // Allow for non-empty overrides of documentation
+ if len(other.doc) != 0 && f.doc != other.doc {
+ merged.doc = other.doc
+ }
 // baseline copy of the overloads and their ordinals
 copy(merged.overloadOrdinals, f.overloadOrdinals)
 for oID, o := range f.overloads {
@@ -202,6 +243,7 @@ func (f *FunctionDecl) Subset(selector OverloadSelector) *FunctionDecl {
 }
 subset := &FunctionDecl{
 name: f.Name(),
+ doc: f.doc,
 overloads: overloads,
 singleton: f.singleton,
 disableTypeGuards: f.disableTypeGuards,
@@ -218,6 +260,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
 if f == nil {
 return fmt.Errorf("nil function cannot add overload: %s", overload.ID())
 }
+ if overload == nil {
+ return fmt.Errorf("cannot add nil overload to function: %s", f.Name())
+ }
 for oID, o := range f.overloads {
 if oID != overload.ID() && o.SignatureOverlaps(overload) {
 return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.Name(), oID, overload.ID())
@@ -228,6 +273,10 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
 if overload.hasBinding() {
 f.overloads[oID] = overload
 }
+ // Allow redefinition of the doc string.
+ if len(overload.doc) != 0 && o.doc != overload.doc {
+ o.doc = overload.doc
+ }
 return nil
 }
 return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
@@ -240,8 +289,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
 // OverloadDecls returns the overload declarations in the order in which they were declared.
func (f *FunctionDecl) OverloadDecls() []*OverloadDecl { + var emptySet []*OverloadDecl if f == nil { - return []*OverloadDecl{} + return emptySet } overloads := make([]*OverloadDecl, 0, len(f.overloads)) for _, oID := range f.overloadOrdinals { @@ -252,8 +302,9 @@ func (f *FunctionDecl) OverloadDecls() []*OverloadDecl { // Bindings produces a set of function bindings, if any are defined. func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) { + var emptySet []*functions.Overload if f == nil { - return []*functions.Overload{}, nil + return emptySet, nil } overloads := []*functions.Overload{} nonStrict := false @@ -361,6 +412,14 @@ func MaybeNoSuchOverload(funcName string, args ...ref.Val) ref.Val { // FunctionOpt defines a functional option for mutating a function declaration. type FunctionOpt func(*FunctionDecl) (*FunctionDecl, error) +// FunctionDocs configures documentation from a list of strings separated by newlines. +func FunctionDocs(docs ...string) FunctionOpt { + return func(fn *FunctionDecl) (*FunctionDecl, error) { + fn.doc = common.MultilineDescription(docs...) + return fn, nil + } +} + // DisableTypeGuards disables automatically generated function invocation guards on direct overload calls. // Type guards remain on during dynamic dispatch for parsed-only expressions. func DisableTypeGuards(value bool) FunctionOpt { @@ -513,6 +572,7 @@ func newOverloadInternal(overloadID string, // implementation. type OverloadDecl struct { id string + doc string argTypes []*types.Type resultType *types.Type isMemberFunction bool @@ -532,6 +592,15 @@ type OverloadDecl struct { functionOp functions.FunctionOp } +// Examples returns a list of string examples for the overload. 
+func (o *OverloadDecl) Examples() []string { + var emptySet []string + if o == nil || len(o.doc) == 0 { + return emptySet + } + return common.ParseDescriptions(o.doc) +} + // ID mirrors the overload signature and provides a unique id which may be referenced within the type-checker // and interpreter to optimize performance. // @@ -729,6 +798,14 @@ func matchOperandTrait(trait int, arg ref.Val) bool { // OverloadOpt is a functional option for configuring a function overload. type OverloadOpt func(*OverloadDecl) (*OverloadDecl, error) +// OverloadExamples configures example expressions for the overload. +func OverloadExamples(examples ...string) OverloadOpt { + return func(o *OverloadDecl) (*OverloadDecl, error) { + o.doc = common.MultilineDescription(examples...) + return o, nil + } +} + // UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func UnaryBinding(binding functions.UnaryOp) OverloadOpt { @@ -800,13 +877,27 @@ func NewVariable(name string, t *types.Type) *VariableDecl { return &VariableDecl{name: name, varType: t} } +// NewVariableWithDoc creates a new variable declaration with usage documentation. +func NewVariableWithDoc(name string, t *types.Type, doc string) *VariableDecl { + return &VariableDecl{name: name, varType: t, doc: doc} +} + // VariableDecl defines a variable declaration which may optionally have a constant value. type VariableDecl struct { name string + doc string varType *types.Type value ref.Val } +// Documentation returns name, type, and description for the variable. 
+func (v *VariableDecl) Documentation() *common.Doc { + if v == nil { + return nil + } + return common.NewVariableDoc(v.Name(), describeCELType(v.Type()), v.Description()) +} + // Name returns the fully-qualified variable name func (v *VariableDecl) Name() string { if v == nil { @@ -815,6 +906,16 @@ func (v *VariableDecl) Name() string { return v.name } +// Description returns the usage documentation for the variable, if set. +// +// Good usage instructions provide information about the valid formats, ranges, sizes for the variable type. +func (v *VariableDecl) Description() string { + if v == nil { + return "" + } + return v.doc +} + // Type returns the types.Type value associated with the variable. func (v *VariableDecl) Type() *types.Type { if v == nil { @@ -856,7 +957,7 @@ func variableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) { if err != nil { return nil, err } - return chkdecls.NewVar(v.Name(), varType), nil + return chkdecls.NewVarWithDoc(v.Name(), varType, v.doc), nil } // FunctionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration. @@ -901,8 +1002,10 @@ func functionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) { overloads[i] = chkdecls.NewParameterizedOverload(oID, argTypes, resultType, params) } } + doc := common.MultilineDescription(o.Examples()...) 
+ overloads[i].Doc = doc } - return chkdecls.NewFunction(f.Name(), overloads...), nil + return chkdecls.NewFunctionWithDoc(f.Name(), f.Description(), overloads...), nil } func collectParamNames(paramNames map[string]struct{}, arg *types.Type) { @@ -914,6 +1017,60 @@ func collectParamNames(paramNames map[string]struct{}, arg *types.Type) { } } +func formatSignature(fnName string, o *OverloadDecl) string { + if opName, isOperator := operators.FindReverse(fnName); isOperator { + if opName == "" { + opName = fnName + } + return formatOperator(opName, o) + } + return formatCall(fnName, o) +} + +func formatOperator(opName string, o *OverloadDecl) string { + args := o.ArgTypes() + argTypes := make([]string, len(o.ArgTypes())) + for j, a := range args { + argTypes[j] = describeCELType(a) + } + ret := describeCELType(o.ResultType()) + switch len(args) { + case 1: + return fmt.Sprintf("%s%s -> %s", opName, argTypes[0], ret) + case 2: + if opName == operators.Index { + return fmt.Sprintf("%s[%s] -> %s", argTypes[0], argTypes[1], ret) + } + return fmt.Sprintf("%s %s %s -> %s", argTypes[0], opName, argTypes[1], ret) + default: + if opName == operators.Conditional { + return fmt.Sprint("bool ? 
: -> ") + } + return formatCall(opName, o) + } +} + +func formatCall(funcName string, o *OverloadDecl) string { + args := make([]string, len(o.ArgTypes())) + ret := describeCELType(o.ResultType()) + for j, a := range o.ArgTypes() { + args[j] = describeCELType(a) + } + if o.IsMemberFunction() { + target := args[0] + args = args[1:] + return fmt.Sprintf("%s.%s(%s) -> %s", target, funcName, strings.Join(args, ", "), ret) + } + return fmt.Sprintf("%s(%s) -> %s", funcName, strings.Join(args, ", "), ret) +} + +func describeCELType(t *types.Type) string { + if t.Kind() == types.TypeKind { + return "type" + } + return t.String() +} + var ( - emptyArgs = []*types.Type{} + emptyArgs []*types.Type ) diff --git a/common/decls/decls_test.go b/common/decls/decls_test.go index e93966007..262ef355d 100644 --- a/common/decls/decls_test.go +++ b/common/decls/decls_test.go @@ -23,6 +23,9 @@ import ( "google.golang.org/protobuf/proto" chkdecls "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" @@ -242,8 +245,141 @@ func TestFunctionSingletonBinding(t *testing.T) { } } +func TestVariableDocumentation(t *testing.T) { + v := NewVariableWithDoc("var", types.StringType, "string variable") + doc := v.Documentation() + if doc.Description != v.Description() { + t.Errorf("doc.Description got %s, wanted %s", doc.Description, v.Description()) + } + if doc.Name != v.Name() { + t.Errorf("doc.Name got %s, wanted %s", doc.Name, v.Name()) + } + if doc.Type != "string" { + t.Errorf("doc.Type got %s, wanted string", doc.Type) + } +} + +func TestFunctionDocumentation(t *testing.T) { + tests := []struct { + name string + fn *FunctionDecl + signatures []string + examples []string + }{ + { + name: "function", + fn: testFunction(t, "size", + FunctionDocs(`compute 
the number of entries in a list or map`), + MemberOverload("list_size", + []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType, + OverloadExamples(`[].size() // 0`, `[1, 2, 3].size() // 3`)), + Overload("size_list", + []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType, + OverloadExamples(`size([]) // 0`, `size([1, 2, 3]) // 3`))), + signatures: []string{ + "list().size() -> int", + "size(list()) -> int", + }, + examples: []string{ + `[].size() // 0`, + `[1, 2, 3].size() // 3`, + `size([]) // 0`, + `size([1, 2, 3]) // 3`, + }, + }, + { + name: "type", + fn: testFunction(t, overloads.TypeConvertType, + Overload(overloads.TypeConvertType, + []*types.Type{types.NewTypeTypeWithParam(types.NewTypeParamType("T"))}, + types.TypeType, + OverloadExamples(`type(int) // type`))), + signatures: []string{"type(type) -> type"}, + examples: []string{`type(int) // type`}, + }, + { + name: "unary operator", + fn: testFunction(t, operators.Negate, + FunctionDocs(`negate a numeric value`), + Overload(overloads.NegateInt64, []*types.Type{types.IntType}, types.IntType, + OverloadExamples(`-(1) // -1`))), + signatures: []string{"-int -> int"}, + examples: []string{`-(1) // -1`}, + }, + { + name: "binary operator", + fn: testFunction(t, operators.Add, + FunctionDocs(`add two numeric values`), + Overload(overloads.AddInt64, []*types.Type{types.IntType, types.IntType}, types.IntType, + OverloadExamples(`1 + 2 // 3`))), + signatures: []string{"int + int -> int"}, + examples: []string{`1 + 2 // 3`}, + }, + { + name: "index operator", + fn: testFunction(t, operators.Index, + FunctionDocs(`access a list by numeric index, zero-based`), + Overload(overloads.IndexList, []*types.Type{types.NewListType(types.NewTypeParamType("T")), types.IntType}, types.NewTypeParamType("T"), + OverloadExamples(`[1, 2, 3, 4][2] // 3`))), + signatures: []string{"list()[int] -> "}, + examples: []string{`[1, 2, 3, 4][2] // 3`}, + }, + { + name: "conditional", + fn: 
testFunction(t, operators.Conditional, + FunctionDocs(`ternary operator`), + Overload(overloads.Conditional, + []*types.Type{types.BoolType, types.NewTypeParamType("T"), types.NewTypeParamType("T")}, + types.NewTypeParamType("T"), + OverloadExamples(`true ? 1 : 2 // 1`))), + signatures: []string{"bool ? : -> "}, + examples: []string{`true ? 1 : 2 // 1`}, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + doc := tc.fn.Documentation() + if doc.Kind != common.DocFunction { + t.Errorf("fn.Documentation() kind got %v, wanted common.DocFunction", doc.Kind) + } + if doc.Description != tc.fn.Description() { + t.Errorf("doc.Description got %s, wanted %s", doc.Description, tc.fn.Description()) + } + if len(doc.Children) != len(tc.fn.OverloadDecls()) { + t.Fatalf("doc.Children count was %d, wanted %d", len(doc.Children), len(tc.fn.OverloadDecls())) + } + for _, child := range doc.Children { + found := false + for _, sig := range tc.signatures { + if child.Signature == sig { + found = true + break + } + } + if !found { + t.Errorf("unable to find signature: %s", child.Signature) + } + for _, childEx := range child.Children { + found = false + for _, ex := range tc.examples { + if strings.Contains(childEx.Description, ex) { + found = true + break + } + } + if !found { + t.Errorf("unable to find example: %v", childEx.Description) + } + } + } + }) + } +} + func TestFunctionMerge(t *testing.T) { sizeFunc, err := NewFunction("size", + FunctionDocs(`compute the number of entries in a list or map`), MemberOverload("list_size", []*types.Type{types.NewListType(types.NewTypeParamType("T"))}, types.IntType), MemberOverload("map_size", @@ -279,6 +415,10 @@ func TestFunctionMerge(t *testing.T) { if len(sizeMerged.overloads) != 3 { t.Errorf("Merge() produced %d overloads, wanted 3", len(sizeFunc.overloads)) } + if sizeMerged.Description() != "compute the number of entries in a list or map" { + t.Errorf("Description() got %s, wanted %s", 
sizeMerged.Description(), + "compute the number of entries in a list or map") + } overloads := map[string]bool{ "list_size": true, "map_size": true, @@ -934,12 +1074,18 @@ func TestFunctionDeclToExprDecl(t *testing.T) { { fn: testMerge(t, testFunction(t, "equals", - Overload("int_equals_uint", []*types.Type{types.IntType, types.UintType}, types.BoolType), - Overload("uint_equals_int", []*types.Type{types.UintType, types.IntType}, types.BoolType)), + FunctionDocs(`test equality between an int and uint only`), + Overload("int_equals_uint", []*types.Type{types.IntType, types.UintType}, types.BoolType, + OverloadExamples(`1 == 1u // true`)), + Overload("uint_equals_int", []*types.Type{types.UintType, types.IntType}, types.BoolType, + OverloadExamples(`1u == -1 // false`))), testFunction(t, "equals", - Overload("int_equals_int", []*types.Type{types.IntType, types.IntType}, types.BoolType), + FunctionDocs(`test equality between two int-like values`), + Overload("int_equals_int", []*types.Type{types.IntType, types.IntType}, types.BoolType, + OverloadExamples(`1 == 1 // true`, `1 == 2 // false`)), Overload("int_equals_uint", []*types.Type{types.IntType, types.UintType}, types.BoolType), - Overload("uint_equals_uint", []*types.Type{types.UintType, types.UintType}, types.BoolType))), + Overload("uint_equals_uint", []*types.Type{types.UintType, types.UintType}, types.BoolType, + OverloadExamples(`1u == 1u // true`)))), exDecl: &exprpb.Decl{ Name: "equals", DeclKind: &exprpb.Decl_Function{ @@ -952,6 +1098,7 @@ func TestFunctionDeclToExprDecl(t *testing.T) { chkdecls.Uint, }, ResultType: chkdecls.Bool, + Doc: "1 == 1u // true", }, { OverloadId: "uint_equals_int", @@ -960,6 +1107,7 @@ func TestFunctionDeclToExprDecl(t *testing.T) { chkdecls.Int, }, ResultType: chkdecls.Bool, + Doc: "1u == -1 // false", }, { OverloadId: "int_equals_int", @@ -968,6 +1116,7 @@ func TestFunctionDeclToExprDecl(t *testing.T) { chkdecls.Int, }, ResultType: chkdecls.Bool, + Doc: "1 == 1 // true\n1 == 2 
// false", }, { OverloadId: "uint_equals_uint", @@ -976,6 +1125,7 @@ func TestFunctionDeclToExprDecl(t *testing.T) { chkdecls.Uint, }, ResultType: chkdecls.Bool, + Doc: "1u == 1u // true", }, }, }, @@ -1056,6 +1206,17 @@ func TestVariableDeclToExprDecl(t *testing.T) { t.Error("proto.Equal() returned false, wanted true") } + docVar := NewVariableWithDoc("a", types.BoolType, "doc") + a, err = VariableDeclToExprDecl(docVar) + if err != nil { + t.Fatalf("VariableDeclToExprDecl() failed: %v", err) + } + if docVar.Description() != a.GetIdent().GetDoc() { + t.Errorf("docVar.Description() got %s, wanted %s", docVar.Description(), a.GetIdent().GetDoc()) + } + if !proto.Equal(a, chkdecls.NewVarWithDoc("a", chkdecls.Bool, "doc")) { + t.Error("proto.Equal() returned false, wanted true") + } } func TestVariableDeclToExprDeclInvalid(t *testing.T) { @@ -1065,6 +1226,57 @@ func TestVariableDeclToExprDeclInvalid(t *testing.T) { } } +func TestNilFunction(t *testing.T) { + var f *FunctionDecl + if f.Name() != "" { + t.Errorf("f.Name() got %s, wanted ''", f.Name()) + } + if f.Description() != "" { + t.Errorf("f.Description() got %s, wanted ''", f.Description()) + } + if f.Documentation() != nil { + t.Errorf("f.Documentation() got %v, wanted nil", f.Documentation()) + } + if !f.IsDeclarationDisabled() { + t.Errorf("f.IsDeclarationDisabled() got %t, wanted true", f.IsDeclarationDisabled()) + } + if len(f.OverloadDecls()) != 0 { + t.Errorf("f.OverloadDecls() got %d overloads, wanted 0", len(f.OverloadDecls())) + } + if b, err := f.Bindings(); err != nil || len(b) != 0 { + t.Error("f.Bindings() got non-empty result") + } + other := &FunctionDecl{} + if fn, err := f.Merge(other); err == nil { + t.Errorf("f.Merge(nil) wanted error, got %v", fn) + } + if err := f.AddOverload(nil); err == nil { + t.Error("f.AddOverload(nil) did not error") + } + if err := other.AddOverload(nil); err == nil { + t.Error("f.AddOverload(nil) did not error") + } +} + +func TestNilVariable(t *testing.T) { + var v 
*VariableDecl + if v.Name() != "" { + t.Errorf("v.Name() got %s, wanted ''", v.Name()) + } + if v.Type() != nil { + t.Errorf("v.Type() got %v, wanted nil", v.Type()) + } + if v.Value() != nil { + t.Errorf("v.Type() got %v, wanted nil", v.Value()) + } + if v.Description() != "" { + t.Errorf("v.Description() got %s, wanted ''", v.Description()) + } + if v.Documentation() != nil { + t.Errorf("v.Documentation() got %v, wanted nil", v.Documentation()) + } +} + func testMerge(t *testing.T, funcs ...*FunctionDecl) *FunctionDecl { t.Helper() fn := funcs[0] diff --git a/common/doc.go b/common/doc.go index c68df818f..06eae3642 100644 --- a/common/doc.go +++ b/common/doc.go @@ -17,7 +17,6 @@ package common import ( - "fmt" "strings" "unicode" ) @@ -40,10 +39,6 @@ const ( DocExample ) -// MultilineDescription represents a description that can span multiple lines, -// stored as a slice of strings. -type MultilineDescription []string - // Doc holds the documentation details for a specific program element like // a variable, function, macro, or example. type Doc struct { @@ -60,59 +55,24 @@ type Doc struct { Signature string // Description holds the textual description of the element, potentially spanning multiple lines. - Description MultilineDescription + Description string // Children holds nested documentation elements, such as overloads for a function // or examples for a function/macro. Children []*Doc } -// FormatDescription joins multiple description elements (string, MultilineDescription, -// or []MultilineDescription) into a single string, separated by double newlines ("\n\n"). -// It returns the formatted string or an error if an unsupported type is encountered. -func FormatDescription(descs ...any) (string, error) { - return FormatDescriptionSeparator("\n\n", descs...) -} - -// FormatDescriptionSeparator joins multiple description elements (string, MultilineDescription, -// or []MultilineDescription) into a single string using the specified separator. 
-// It returns the formatted string or an error if an unsupported description type is passed. -func FormatDescriptionSeparator(sep string, descs ...any) (string, error) { - var builder strings.Builder - hasDoc := false - for _, d := range descs { - if hasDoc { - builder.WriteString(sep) - } - switch v := d.(type) { - case string: - builder.WriteString(v) - case MultilineDescription: - str := strings.Join(v, "\n") - builder.WriteString(str) - case []MultilineDescription: - for _, md := range v { - if hasDoc { - builder.WriteString(sep) - } - str := strings.Join(md, "\n") - builder.WriteString(str) - hasDoc = true - } - default: - return "", fmt.Errorf("unsupported description type: %T", d) - } - hasDoc = true - } - return builder.String(), nil +// MultilineDescription combines multiple lines into a newline separated string. +func MultilineDescription(lines ...string) string { + return strings.Join(lines, "\n") } // ParseDescription takes a single string containing newline characters and splits -// it into a MultilineDescription. All empty lines will be skipped. +// it into a multiline description. All empty lines will be skipped. // -// Returns an empty MultilineDescription if the input string is empty. -func ParseDescription(doc string) MultilineDescription { - var lines MultilineDescription +// Returns an empty string if the input string is empty. +func ParseDescription(doc string) string { + var lines []string if len(doc) != 0 { // Split the input string by newline characters. for _, line := range strings.Split(doc, "\n") { @@ -124,13 +84,13 @@ func ParseDescription(doc string) MultilineDescription { } } // Return an empty slice if the input is empty. - return lines + return MultilineDescription(lines...) } -// ParseDescriptions splits a documentation string into multiple MultilineDescription +// ParseDescriptions splits a documentation string into multiple multi-line description // sections, using blank lines as delimiters. 
-func ParseDescriptions(doc string) []MultilineDescription { - var examples []MultilineDescription +func ParseDescriptions(doc string) []string { + var examples []string if len(doc) != 0 { lines := strings.Split(doc, "\n") lineStart := 0 @@ -142,14 +102,14 @@ func ParseDescriptions(doc string) []MultilineDescription { // Start the next section after the blank line. ex := lines[lineStart:i] if len(ex) != 0 { - examples = append(examples, ex) + examples = append(examples, MultilineDescription(ex...)) } lineStart = i + 1 } } // Append the last section if it wasn't terminated by a blank line. if lineStart < len(lines) { - examples = append(examples, lines[lineStart:]) + examples = append(examples, MultilineDescription(lines[lineStart:]...)) } } return examples @@ -196,7 +156,7 @@ func NewMacroDoc(name, description string, examples ...*Doc) *Doc { } // NewExampleDoc creates a new Doc struct specifically for holding an example. -func NewExampleDoc(ex MultilineDescription) *Doc { +func NewExampleDoc(ex string) *Doc { return &Doc{ Kind: DocExample, Description: ex, diff --git a/common/doc_test.go b/common/doc_test.go index 2f06b3d8c..c84c3678a 100644 --- a/common/doc_test.go +++ b/common/doc_test.go @@ -15,63 +15,15 @@ package common import ( - "errors" "reflect" - "strings" "testing" ) -func TestFormatDescription(t *testing.T) { - tests := []struct { - name string - in []any - out string - err error - }{ - { - name: "two separate examples as strings", - in: []any{"hello", "world"}, - out: "hello\n\nworld", - }, - { - name: "single example as multiline string", - in: []any{MultilineDescription{"hello", "world"}}, - out: "hello\nworld", - }, - { - name: "two examples as a list of multiline strings", - in: []any{[]MultilineDescription{{"hello", "world"}, {"goodbye", "cruel world"}}}, - out: "hello\nworld\n\ngoodbye\ncruel world", - }, - { - name: "invalid description", - in: []any{1}, - err: errors.New("unsupported description type"), - }, - } - - for _, tst := range tests { 
- tc := tst - t.Run(tc.name, func(t *testing.T) { - out, err := FormatDescription(tc.in...) - if err != nil { - if tc.err == nil || !strings.Contains(err.Error(), tc.err.Error()) { - t.Fatalf("FormatDescription() errored with %v, wanted %v", err, tc) - } - return - } - if out != tc.out { - t.Errorf("FormatDescription() got %s, wanted %v", out, tc) - } - }) - } -} - func TestParseDescription(t *testing.T) { tests := []struct { name string in string - out MultilineDescription + out string }{ { name: "empty", @@ -79,12 +31,12 @@ func TestParseDescription(t *testing.T) { { name: "single", in: "hello", - out: MultilineDescription{"hello"}, + out: "hello", }, { name: "multi", in: "hello\n\n\nworld", - out: MultilineDescription{"hello", "world"}, + out: "hello\nworld", }, } @@ -103,7 +55,7 @@ func TestParseDescriptions(t *testing.T) { tests := []struct { name string in string - out []MultilineDescription + out []string }{ { name: "empty", @@ -111,17 +63,17 @@ func TestParseDescriptions(t *testing.T) { { name: "single", in: "hello", - out: []MultilineDescription{{"hello"}}, + out: []string{"hello"}, }, { name: "multi", in: "bar\nbaz\n\nfoo", - out: []MultilineDescription{{"bar", "baz"}, {"foo"}}, + out: []string{"bar\nbaz", "foo"}, }, { name: "multi", in: "hello\n\n\nworld", - out: []MultilineDescription{{"hello"}, {"world"}}, + out: []string{"hello", "world"}, }, } @@ -143,17 +95,17 @@ func TestNewDoc(t *testing.T) { name string celType string sig string - desc MultilineDescription + desc string childCount int }{ { newDoc: func() *Doc { return NewMacroDoc("map", "map converts a list or map of values to a list", - NewExampleDoc(MultilineDescription{"[1, 2].map(i, i * 2) // [2, 4]"})) + NewExampleDoc("[1, 2].map(i, i * 2) // [2, 4]")) }, kind: DocMacro, name: "map", - desc: MultilineDescription{"map converts a list or map of values to a list"}, + desc: "map converts a list or map of values to a list", childCount: 1, }, { @@ -166,7 +118,7 @@ func TestNewDoc(t *testing.T) { 
kind: DocVariable, name: "request", celType: "google.rpc.context.AttributeContext.Request", - desc: MultilineDescription{"parameters related to an HTTP API request"}, + desc: "parameters related to an HTTP API request", childCount: 0, }, { @@ -174,11 +126,11 @@ func TestNewDoc(t *testing.T) { return NewFunctionDoc("getToken", "get the JWT token from a request\nas deserialized JSON", NewOverloadDoc("request_getToken", "request.getToken() -> map(string, dyn)", - NewExampleDoc(MultilineDescription{"has(request.getToken().sub) // false"}))) + NewExampleDoc("has(request.getToken().sub) // false"))) }, kind: DocFunction, name: "getToken", - desc: MultilineDescription{"get the JWT token from a request", "as deserialized JSON"}, + desc: "get the JWT token from a request\nas deserialized JSON", childCount: 1, }, } diff --git a/common/env/BUILD.bazel b/common/env/BUILD.bazel index 0e7dae1d1..aebe1e544 100644 --- a/common/env/BUILD.bazel +++ b/common/env/BUILD.bazel @@ -26,6 +26,7 @@ go_library( ], importpath = "github.com/google/cel-go/common/env", deps = [ + "//common:go_default_library", "//common/decls:go_default_library", "//common/types:go_default_library", ], diff --git a/common/env/env.go b/common/env/env.go index 8e57c42e3..4f2bebade 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -115,7 +115,9 @@ func (c *Config) AddVariableDecls(vars ...*decls.VariableDecl) *Config { if v == nil { continue } - convVars[i] = NewVariable(v.Name(), SerializeTypeDesc(v.Type())) + cv := NewVariable(v.Name(), SerializeTypeDesc(v.Type())) + cv.Description = v.Description() + convVars[i] = cv } return c.AddVariables(convVars...) 
} @@ -149,13 +151,21 @@ func (c *Config) AddFunctionDecls(funcs ...*decls.FunctionDecl) *Config { args = append(args, SerializeTypeDesc(a)) } ret := SerializeTypeDesc(o.ResultType()) + var overload *Overload if o.IsMemberFunction() { - overloads = append(overloads, NewMemberOverload(overloadID, args[0], args[1:], ret)) + overload = NewMemberOverload(overloadID, args[0], args[1:], ret) } else { - overloads = append(overloads, NewOverload(overloadID, args, ret)) + overload = NewOverload(overloadID, args, ret) } + exampleCount := len(o.Examples()) + if exampleCount > 0 { + overload.Examples = o.Examples() + } + overloads = append(overloads, overload) } - convFuncs[i] = NewFunction(fn.Name(), overloads...) + cf := NewFunction(fn.Name(), overloads...) + cf.Description = fn.Description() + convFuncs[i] = cf } return c.AddFunctions(convFuncs...) } @@ -220,7 +230,12 @@ func (imp *Import) Validate() error { // NewVariable returns a serializable variable from a name and type definition func NewVariable(name string, t *TypeDesc) *Variable { - return &Variable{Name: name, TypeDesc: t} + return NewVariableWithDoc(name, t, "") +} + +// NewVariable returns a serializable variable from a name, type definition, and doc string. +func NewVariableWithDoc(name string, t *TypeDesc, doc string) *Variable { + return &Variable{Name: name, TypeDesc: t, Description: doc} } // Variable represents a typed variable declaration which will be published via the @@ -278,7 +293,7 @@ func (v *Variable) AsCELVariable(tp types.Provider) (*decls.VariableDecl, error) if err != nil { return nil, fmt.Errorf("invalid variable %q: %w", v.Name, err) } - return decls.NewVariable(v.Name, t), nil + return decls.NewVariableWithDoc(v.Name, t, v.Description), nil } // NewContextVariable returns a serializable context variable with a specific type name. 
@@ -310,6 +325,11 @@ func NewFunction(name string, overloads ...*Overload) *Function { return &Function{Name: name, Overloads: overloads} } +// NewFunctionWithDoc creates a serializable function and overload set. +func NewFunctionWithDoc(name, doc string, overloads ...*Overload) *Function { + return &Function{Name: name, Description: doc, Overloads: overloads} +} + // Function represents the serializable format of a function and its overloads. type Function struct { Name string `yaml:"name"` @@ -342,34 +362,37 @@ func (fn *Function) AsCELFunction(tp types.Provider) (*decls.FunctionDecl, error if err := fn.Validate(); err != nil { return nil, err } - var err error - overloads := make([]decls.FunctionOpt, len(fn.Overloads)) - for i, o := range fn.Overloads { - overloads[i], err = o.AsFunctionOption(tp) + opts := make([]decls.FunctionOpt, 0, len(fn.Overloads)+1) + for _, o := range fn.Overloads { + opt, err := o.AsFunctionOption(tp) + opts = append(opts, opt) if err != nil { return nil, fmt.Errorf("invalid function %q: %w", fn.Name, err) } } - return decls.NewFunction(fn.Name, overloads...) + if len(fn.Description) != 0 { + opts = append(opts, decls.FunctionDocs(fn.Description)) + } + return decls.NewFunction(fn.Name, opts...) } // NewOverload returns a new serializable representation of a global overload. -func NewOverload(id string, args []*TypeDesc, ret *TypeDesc) *Overload { - return &Overload{ID: id, Args: args, Return: ret} +func NewOverload(id string, args []*TypeDesc, ret *TypeDesc, examples ...string) *Overload { + return &Overload{ID: id, Args: args, Return: ret, Examples: examples} } // NewMemberOverload returns a new serializable representation of a member (receiver) overload. 
-func NewMemberOverload(id string, target *TypeDesc, args []*TypeDesc, ret *TypeDesc) *Overload { - return &Overload{ID: id, Target: target, Args: args, Return: ret} +func NewMemberOverload(id string, target *TypeDesc, args []*TypeDesc, ret *TypeDesc, examples ...string) *Overload { + return &Overload{ID: id, Target: target, Args: args, Return: ret, Examples: examples} } // Overload represents the serializable format of a function overload. type Overload struct { - ID string `yaml:"id"` - Description string `yaml:"description,omitempty"` - Target *TypeDesc `yaml:"target,omitempty"` - Args []*TypeDesc `yaml:"args,omitempty"` - Return *TypeDesc `yaml:"return,omitempty"` + ID string `yaml:"id"` + Examples []string `yaml:"examples,omitempty"` + Target *TypeDesc `yaml:"target,omitempty"` + Args []*TypeDesc `yaml:"args,omitempty"` + Return *TypeDesc `yaml:"return,omitempty"` } // Validate validates the overload configuration is well-formed. @@ -426,7 +449,7 @@ func (od *Overload) AsFunctionOption(tp types.Provider) (decls.FunctionOpt, erro if len(errs) != 0 { return nil, errors.Join(errs...) } - return decls.Overload(od.ID, args, result), nil + return decls.Overload(od.ID, args, result, decls.OverloadExamples(od.Examples...)), nil } // NewExtension creates a serializable Extension from a name and version string. diff --git a/common/env/env_test.go b/common/env/env_test.go index 2e2cd1d86..cc9be0334 100644 --- a/common/env/env_test.go +++ b/common/env/env_test.go @@ -25,6 +25,7 @@ import ( "gopkg.in/yaml.v3" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" @@ -55,16 +56,21 @@ func TestConfig(t *testing.T) { AddExtensions(NewExtension("optional", math.MaxUint32), NewExtension("strings", 1)). SetContextVariable(NewContextVariable("google.expr.proto3.test.TestAllTypes")). 
AddFunctions( - NewFunction("coalesce", + NewFunctionWithDoc("coalesce", + "Converts a potentially null wrapper-type to a default value.", NewOverload("coalesce_wrapped_int", []*TypeDesc{NewTypeDesc("google.protobuf.Int64Value"), NewTypeDesc("int")}, - NewTypeDesc("int")), + NewTypeDesc("int"), + `coalesce(null, 1) // 1`, + `coalesce(2, 1) // 2`), NewOverload("coalesce_wrapped_double", []*TypeDesc{NewTypeDesc("google.protobuf.DoubleValue"), NewTypeDesc("double")}, - NewTypeDesc("double")), + NewTypeDesc("double"), + `coalesce(null, 1.3) // 1.3`), NewOverload("coalesce_wrapped_uint", []*TypeDesc{NewTypeDesc("google.protobuf.UInt64Value"), NewTypeDesc("uint")}, - NewTypeDesc("uint")), + NewTypeDesc("uint"), + `coalesce(null, 14u) // 14u`), ), ), }, @@ -76,15 +82,23 @@ func TestConfig(t *testing.T) { NewExtension("optional", 2), NewExtension("math", math.MaxUint32), ).AddVariables( - NewVariable("msg", NewTypeDesc("google.expr.proto3.test.TestAllTypes")), + NewVariableWithDoc("msg", + NewTypeDesc("google.expr.proto3.test.TestAllTypes"), + `msg represents all possible type permutation which CEL understands from a proto perspective`), ).AddFunctions( - NewFunction("isEmpty", + NewFunctionWithDoc("isEmpty", + common.MultilineDescription( + `determines whether a list is empty,`, + `or a string has no characters`), NewMemberOverload("wrapper_string_isEmpty", NewTypeDesc("google.protobuf.StringValue"), nil, - NewTypeDesc("bool")), + NewTypeDesc("bool"), + `''.isEmpty() // true`), NewMemberOverload("list_isEmpty", NewTypeDesc("list", NewTypeParam("T")), nil, - NewTypeDesc("bool")), + NewTypeDesc("bool"), + `[].isEmpty() // true`, + `[1].isEmpty() // false`), ), ).AddFeatures( NewFeature("cel.feature.macro_call_tracking", true), @@ -153,7 +167,7 @@ func TestConfig(t *testing.T) { for i, v := range got.Variables { wv := tc.want.Variables[i] if !reflect.DeepEqual(v, wv) { - t.Errorf("Variables[%d] not equal, got %v, wanted %v", i, v, wv) + t.Errorf("Variables[%d] not equal, got 
%v, wanted %v", i, v.Description, wv.Description) } } } @@ -165,12 +179,16 @@ func TestConfig(t *testing.T) { if f.Name != wf.Name { t.Errorf("Functions[%d] not equal, got %v, wanted %v", i, f.Name, wf.Name) } + if f.Description != wf.Description { + t.Errorf("Functions[%d] description not equal, got %v, wanted %v", i, f.Description, wf.Description) + } if len(f.Overloads) != len(wf.Overloads) { t.Errorf("Function %s got overload count: %d, wanted %d", f.Name, len(f.Overloads), len(wf.Overloads)) } for j, o := range f.Overloads { wo := wf.Overloads[j] if !reflect.DeepEqual(o, wo) { + t.Error(o.Examples) t.Errorf("Overload[%d] got %v, wanted %v", j, o, wo) } } @@ -313,6 +331,11 @@ func TestConfigAddVariableDecls(t *testing.T) { in: decls.NewVariable("var", types.NewObjectType("google.type.Expr")), out: NewVariable("var", NewTypeDesc("google.type.Expr")), }, + { + name: "proto var decl with doc", + in: decls.NewVariableWithDoc("var", types.NewObjectType("google.type.Expr"), "API-friendly CEL expression type"), + out: NewVariableWithDoc("var", NewTypeDesc("google.type.Expr"), "API-friendly CEL expression type"), + }, } for _, tst := range tests { tc := tst @@ -372,6 +395,19 @@ func TestConfigAddFunctionDecls(t *testing.T) { NewMemberOverload("string_size", NewTypeDesc("string"), []*TypeDesc{}, NewTypeDesc("int")), ), }, + { + name: "global function decl - with doc", + in: mustNewFunction(t, "size", + decls.FunctionDocs("return the number of unicode code points", "in a string"), + decls.Overload("size_string", []*types.Type{types.StringType}, types.IntType, + decls.OverloadExamples(`'hello'.size() // 5`)), + ), + out: NewFunctionWithDoc("size", + "return the number of unicode code points\nin a string", + NewOverload("size_string", []*TypeDesc{NewTypeDesc("string")}, NewTypeDesc("int"), + `'hello'.size() // 5`), + ), + }, } for _, tst := range tests { tc := tst @@ -729,6 +765,27 @@ func TestFunctionAsCELFunction(t *testing.T) { }, want: mustNewFunction(t, "size", 
decls.MemberOverload("string_size", []*types.Type{types.StringType}, types.IntType)), }, + { + name: "member function", + f: &Function{Name: "size", + Description: "return the number of unicode code points in a string", + Overloads: []*Overload{{ + ID: "string_size", + Target: &TypeDesc{TypeName: "string"}, + Return: &TypeDesc{TypeName: "int"}, + Examples: []string{ + `'hello'.size() // 5`, + `'hello world'.size() // 11`, + }, + }}, + }, + want: mustNewFunction(t, "size", + decls.FunctionDocs("return the number of unicode code points in a string"), + decls.MemberOverload("string_size", []*types.Type{types.StringType}, types.IntType, + decls.OverloadExamples( + `'hello'.size() // 5`, + `'hello world'.size() // 11`))), + }, } tp, err := types.NewRegistry() if err != nil { diff --git a/common/env/testdata/context_env.yaml b/common/env/testdata/context_env.yaml index 348d2d723..9b0ca9188 100644 --- a/common/env/testdata/context_env.yaml +++ b/common/env/testdata/context_env.yaml @@ -36,20 +36,28 @@ context_variable: type_name: "google.expr.proto3.test.TestAllTypes" functions: - name: "coalesce" + description: "Converts a potentially null wrapper-type to a default value." 
overloads: - id: "coalesce_wrapped_int" + examples: + - "coalesce(null, 1) // 1" + - "coalesce(2, 1) // 2" args: - type_name: "google.protobuf.Int64Value" - type_name: "int" return: type_name: "int" - id: "coalesce_wrapped_double" + examples: + - "coalesce(null, 1.3) // 1.3" args: - type_name: "google.protobuf.DoubleValue" - type_name: "double" return: type_name: "double" - id: "coalesce_wrapped_uint" + examples: + - "coalesce(null, 14u) // 14u" args: - type_name: "google.protobuf.UInt64Value" - type_name: "uint" diff --git a/common/env/testdata/extended_env.yaml b/common/env/testdata/extended_env.yaml index 808e0a89e..5e9e6fcda 100644 --- a/common/env/testdata/extended_env.yaml +++ b/common/env/testdata/extended_env.yaml @@ -22,15 +22,26 @@ extensions: variables: - name: "msg" type_name: "google.expr.proto3.test.TestAllTypes" + description: >- + msg represents all possible type permutation which + CEL understands from a proto perspective functions: - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters overloads: - id: "wrapper_string_isEmpty" + examples: + - "''.isEmpty() // true" target: type_name: "google.protobuf.StringValue" return: type_name: "bool" - id: "list_isEmpty" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" target: type_name: "list" params: diff --git a/common/env/testdata/subset_env.yaml b/common/env/testdata/subset_env.yaml index 44437e718..53adef486 100644 --- a/common/env/testdata/subset_env.yaml +++ b/common/env/testdata/subset_env.yaml @@ -33,6 +33,7 @@ stdlib: variables: - name: "x" type_name: "int" + description: - name: "y" type_name: "double" - name: "z" From 535d5615c05c9d55c7c8f608b2cf7c7e7134e83f Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 10 Apr 2025 11:06:13 -0700 Subject: [PATCH 29/46] Document the standard library macros and functions (#1159) * Support for documentation and example strings in CEL environments --- common/stdlib/BUILD.bazel | 1 + 
common/stdlib/standard.go | 490 +++++++++++++++++++++++++++++--------- parser/macro.go | 188 +++++++++++++-- 3 files changed, 555 insertions(+), 124 deletions(-) diff --git a/common/stdlib/BUILD.bazel b/common/stdlib/BUILD.bazel index b55f45215..124dbea81 100644 --- a/common/stdlib/BUILD.bazel +++ b/common/stdlib/BUILD.bazel @@ -12,6 +12,7 @@ go_library( ], importpath = "github.com/google/cel-go/common/stdlib", deps = [ + "//common:go_default_library", "//common/decls:go_default_library", "//common/functions:go_default_library", "//common/operators:go_default_library", diff --git a/common/stdlib/standard.go b/common/stdlib/standard.go index cbaa7d072..4040a4f5c 100644 --- a/common/stdlib/standard.go +++ b/common/stdlib/standard.go @@ -20,6 +20,7 @@ import ( "strings" "time" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/operators" @@ -60,19 +61,49 @@ func init() { // Logical operators. Special-cased within the interpreter. // Note, the singleton binding prevents extensions from overriding the operator behavior. function(operators.Conditional, + decls.FunctionDocs( + `The ternary operator tests a boolean predicate and returns the left-hand side `+ + `(truthy) expression if true, or the right-hand side (falsy) expression if false`), decls.Overload(overloads.Conditional, argTypes(types.BoolType, paramA, paramA), paramA, - decls.OverloadIsNonStrict()), + decls.OverloadIsNonStrict(), + decls.OverloadExamples( + `'hello'.contains('lo') ? 'hi' : 'bye' // 'hi'`, + `32 % 3 == 0 ? 'divisible' : 'not divisible' // 'not divisible'`)), decls.SingletonFunctionBinding(noFunctionOverrides)), + function(operators.LogicalAnd, + decls.FunctionDocs( + `logically AND two boolean values. 
Errors and unknown values`, + `are valid inputs and will not halt evaluation.`), decls.Overload(overloads.LogicalAnd, argTypes(types.BoolType, types.BoolType), types.BoolType, - decls.OverloadIsNonStrict()), + decls.OverloadIsNonStrict(), + decls.OverloadExamples( + `true && true // true`, + `true && false // false`, + `error && true // error`, + `error && false // false`)), decls.SingletonBinaryBinding(noBinaryOverrides)), + function(operators.LogicalOr, + decls.FunctionDocs( + `logically OR two boolean values. Errors and unknown values`, + `are valid inputs and will not halt evaluation.`), decls.Overload(overloads.LogicalOr, argTypes(types.BoolType, types.BoolType), types.BoolType, - decls.OverloadIsNonStrict()), + decls.OverloadIsNonStrict(), + decls.OverloadExamples( + `true || false // true`, + `false || false // false`, + `error || true // true`, + `error || error // true`)), decls.SingletonBinaryBinding(noBinaryOverrides)), + function(operators.LogicalNot, - decls.Overload(overloads.LogicalNot, argTypes(types.BoolType), types.BoolType), + decls.FunctionDocs(`logically negate a boolean value.`), + decls.Overload(overloads.LogicalNot, argTypes(types.BoolType), types.BoolType, + decls.OverloadExamples( + `!true // false`, + `!false // true`, + `!error // error`)), decls.SingletonUnaryBinding(func(val ref.Val) ref.Val { b, ok := val.(types.Bool) if !ok { @@ -95,66 +126,104 @@ func init() { // Equality / inequality. 
Special-cased in the interpreter function(operators.Equals, - decls.Overload(overloads.Equals, argTypes(paramA, paramA), types.BoolType), + decls.FunctionDocs(`compare two values of the same type for equality`), + decls.Overload(overloads.Equals, argTypes(paramA, paramA), types.BoolType, + decls.OverloadExamples( + `1 == 1 // true`, + `'hello' == 'world' // false`, + `bytes('hello') == b'hello' // true`, + `duration('1h') == duration('60m') // true`, + `dyn(3.0) == 3 // true`)), decls.SingletonBinaryBinding(noBinaryOverrides)), function(operators.NotEquals, - decls.Overload(overloads.NotEquals, argTypes(paramA, paramA), types.BoolType), + decls.FunctionDocs(`compare two values of the same type for inequality`), + decls.Overload(overloads.NotEquals, argTypes(paramA, paramA), types.BoolType, + decls.OverloadExamples( + `1 != 2 // true`, + `"a" != "a" // false`, + `3.0 != 3.1 // true`)), decls.SingletonBinaryBinding(noBinaryOverrides)), // Mathematical operators function(operators.Add, + decls.FunctionDocs( + `adds two numeric values or concatenates two strings, bytes,`, + `or lists.`), decls.Overload(overloads.AddBytes, - argTypes(types.BytesType, types.BytesType), types.BytesType), + argTypes(types.BytesType, types.BytesType), types.BytesType, + decls.OverloadExamples(`b'hi' + bytes('ya') // b'hiya'`)), decls.Overload(overloads.AddDouble, - argTypes(types.DoubleType, types.DoubleType), types.DoubleType), + argTypes(types.DoubleType, types.DoubleType), types.DoubleType, + decls.OverloadExamples(`3.14 + 1.59 // 4.73`)), decls.Overload(overloads.AddDurationDuration, - argTypes(types.DurationType, types.DurationType), types.DurationType), + argTypes(types.DurationType, types.DurationType), types.DurationType, + decls.OverloadExamples(`duration('1m') + duration('1s') // duration('1m1s')`)), decls.Overload(overloads.AddDurationTimestamp, - argTypes(types.DurationType, types.TimestampType), types.TimestampType), + argTypes(types.DurationType, types.TimestampType), 
types.TimestampType, + decls.OverloadExamples(`duration('24h') + timestamp('2023-01-01T00:00:00Z') // timestamp('2023-01-02T00:00:00Z')`)), decls.Overload(overloads.AddTimestampDuration, - argTypes(types.TimestampType, types.DurationType), types.TimestampType), + argTypes(types.TimestampType, types.DurationType), types.TimestampType, + decls.OverloadExamples(`timestamp('2023-01-01T00:00:00Z') + duration('24h1m2s') // timestamp('2023-01-02T00:01:02Z')`)), decls.Overload(overloads.AddInt64, - argTypes(types.IntType, types.IntType), types.IntType), + argTypes(types.IntType, types.IntType), types.IntType, + decls.OverloadExamples(`1 + 2 // 3`)), decls.Overload(overloads.AddList, - argTypes(listOfA, listOfA), listOfA), + argTypes(listOfA, listOfA), listOfA, + decls.OverloadExamples(`[1] + [2, 3] // [1, 2, 3]`)), decls.Overload(overloads.AddString, - argTypes(types.StringType, types.StringType), types.StringType), + argTypes(types.StringType, types.StringType), types.StringType, + decls.OverloadExamples(`"Hello, " + "world!" 
// "Hello, world!"`)), decls.Overload(overloads.AddUint64, - argTypes(types.UintType, types.UintType), types.UintType), + argTypes(types.UintType, types.UintType), types.UintType, + decls.OverloadExamples(`22u + 33u // 55u`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { return lhs.(traits.Adder).Add(rhs) }, traits.AdderType)), function(operators.Divide, + decls.FunctionDocs(`divide two numbers`), decls.Overload(overloads.DivideDouble, - argTypes(types.DoubleType, types.DoubleType), types.DoubleType), + argTypes(types.DoubleType, types.DoubleType), types.DoubleType, + decls.OverloadExamples(`7.0 / 2.0 // 3.5`)), decls.Overload(overloads.DivideInt64, - argTypes(types.IntType, types.IntType), types.IntType), + argTypes(types.IntType, types.IntType), types.IntType, + decls.OverloadExamples(`10 / 2 // 5`)), decls.Overload(overloads.DivideUint64, - argTypes(types.UintType, types.UintType), types.UintType), + argTypes(types.UintType, types.UintType), types.UintType, + decls.OverloadExamples(`42u / 2u // 21u`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { return lhs.(traits.Divider).Divide(rhs) }, traits.DividerType)), function(operators.Modulo, + decls.FunctionDocs(`compute the modulus of one integer into another`), decls.Overload(overloads.ModuloInt64, - argTypes(types.IntType, types.IntType), types.IntType), + argTypes(types.IntType, types.IntType), types.IntType, + decls.OverloadExamples(`3 % 2 // 1`)), decls.Overload(overloads.ModuloUint64, - argTypes(types.UintType, types.UintType), types.UintType), + argTypes(types.UintType, types.UintType), types.UintType, + decls.OverloadExamples(`6u % 3u // 0u`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { return lhs.(traits.Modder).Modulo(rhs) }, traits.ModderType)), function(operators.Multiply, + decls.FunctionDocs(`multiply two numbers`), decls.Overload(overloads.MultiplyDouble, - argTypes(types.DoubleType, types.DoubleType), types.DoubleType), + argTypes(types.DoubleType, 
types.DoubleType), types.DoubleType, + decls.OverloadExamples(`3.5 * 40.0 // 140.0`)), decls.Overload(overloads.MultiplyInt64, - argTypes(types.IntType, types.IntType), types.IntType), + argTypes(types.IntType, types.IntType), types.IntType, + decls.OverloadExamples(`-2 * 6 // -12`)), decls.Overload(overloads.MultiplyUint64, - argTypes(types.UintType, types.UintType), types.UintType), + argTypes(types.UintType, types.UintType), types.UintType, + decls.OverloadExamples(`13u * 3u // 39u`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { return lhs.(traits.Multiplier).Multiply(rhs) }, traits.MultiplierType)), function(operators.Negate, - decls.Overload(overloads.NegateDouble, argTypes(types.DoubleType), types.DoubleType), - decls.Overload(overloads.NegateInt64, argTypes(types.IntType), types.IntType), + decls.FunctionDocs(`negate a numeric value`), + decls.Overload(overloads.NegateDouble, argTypes(types.DoubleType), types.DoubleType, + decls.OverloadExamples(`-(3.14) // -3.14`)), + decls.Overload(overloads.NegateInt64, argTypes(types.IntType), types.IntType, + decls.OverloadExamples(`-(5) // -5`)), decls.SingletonUnaryBinding(func(val ref.Val) ref.Val { if types.IsBool(val) { return types.MaybeNoSuchOverloadErr(val) @@ -162,18 +231,32 @@ func init() { return val.(traits.Negater).Negate() }, traits.NegatorType)), function(operators.Subtract, + decls.FunctionDocs(`subtract two numbers, or two time-related values`), decls.Overload(overloads.SubtractDouble, - argTypes(types.DoubleType, types.DoubleType), types.DoubleType), + argTypes(types.DoubleType, types.DoubleType), types.DoubleType, + decls.OverloadExamples(`10.5 - 2.0 // 8.5`)), decls.Overload(overloads.SubtractDurationDuration, - argTypes(types.DurationType, types.DurationType), types.DurationType), + argTypes(types.DurationType, types.DurationType), types.DurationType, + decls.OverloadExamples(`duration('1m') - duration('1s') // duration('59s')`)), decls.Overload(overloads.SubtractInt64, - 
argTypes(types.IntType, types.IntType), types.IntType), + argTypes(types.IntType, types.IntType), types.IntType, + decls.OverloadExamples(`5 - 3 // 2`)), decls.Overload(overloads.SubtractTimestampDuration, - argTypes(types.TimestampType, types.DurationType), types.TimestampType), + argTypes(types.TimestampType, types.DurationType), types.TimestampType, + decls.OverloadExamples(common.MultilineDescription( + `timestamp('2023-01-10T12:00:00Z')`, + ` - duration('12h') // timestamp('2023-01-10T00:00:00Z')`))), decls.Overload(overloads.SubtractTimestampTimestamp, - argTypes(types.TimestampType, types.TimestampType), types.DurationType), + argTypes(types.TimestampType, types.TimestampType), types.DurationType, + decls.OverloadExamples(common.MultilineDescription( + `timestamp('2023-01-10T12:00:00Z')`, + ` - timestamp('2023-01-10T00:00:00Z') // duration('12h')`))), decls.Overload(overloads.SubtractUint64, - argTypes(types.UintType, types.UintType), types.UintType), + argTypes(types.UintType, types.UintType), types.UintType, + decls.OverloadExamples(common.MultilineDescription( + `// the subtraction result must be positive, otherwise an overflow`, + `// error is generated.`, + `42u - 3u // 39u`))), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { return lhs.(traits.Subtractor).Subtract(rhs) }, traits.SubtractorType)), @@ -181,34 +264,51 @@ func init() { // Relations operators function(operators.Less, + decls.FunctionDocs( + `compare two values and return true if the first value is`, + `less than the second`), decls.Overload(overloads.LessBool, - argTypes(types.BoolType, types.BoolType), types.BoolType), + argTypes(types.BoolType, types.BoolType), types.BoolType, + decls.OverloadExamples(`false < true // true`)), decls.Overload(overloads.LessInt64, - argTypes(types.IntType, types.IntType), types.BoolType), + argTypes(types.IntType, types.IntType), types.BoolType, + decls.OverloadExamples(`-2 < 3 // true`, `1 < 0 // false`)), 
decls.Overload(overloads.LessInt64Double, - argTypes(types.IntType, types.DoubleType), types.BoolType), + argTypes(types.IntType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`1 < 1.1 // true`)), decls.Overload(overloads.LessInt64Uint64, - argTypes(types.IntType, types.UintType), types.BoolType), + argTypes(types.IntType, types.UintType), types.BoolType, + decls.OverloadExamples(`1 < 2u // true`)), decls.Overload(overloads.LessUint64, - argTypes(types.UintType, types.UintType), types.BoolType), + argTypes(types.UintType, types.UintType), types.BoolType, + decls.OverloadExamples(`1u < 2u // true`)), decls.Overload(overloads.LessUint64Double, - argTypes(types.UintType, types.DoubleType), types.BoolType), + argTypes(types.UintType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`1u < 0.9 // false`)), decls.Overload(overloads.LessUint64Int64, - argTypes(types.UintType, types.IntType), types.BoolType), + argTypes(types.UintType, types.IntType), types.BoolType, + decls.OverloadExamples(`1u < 23 // true`, `1u < -1 // false`)), decls.Overload(overloads.LessDouble, - argTypes(types.DoubleType, types.DoubleType), types.BoolType), + argTypes(types.DoubleType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`2.0 < 2.4 // true`)), decls.Overload(overloads.LessDoubleInt64, - argTypes(types.DoubleType, types.IntType), types.BoolType), + argTypes(types.DoubleType, types.IntType), types.BoolType, + decls.OverloadExamples(`2.1 < 3 // true`)), decls.Overload(overloads.LessDoubleUint64, - argTypes(types.DoubleType, types.UintType), types.BoolType), + argTypes(types.DoubleType, types.UintType), types.BoolType, + decls.OverloadExamples(`2.3 < 2u // false`, `-1.0 < 1u // true`)), decls.Overload(overloads.LessString, - argTypes(types.StringType, types.StringType), types.BoolType), + argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples(`'a' < 'b' // true`, `'cat' < 'cab' // false`)), 
decls.Overload(overloads.LessBytes, - argTypes(types.BytesType, types.BytesType), types.BoolType), + argTypes(types.BytesType, types.BytesType), types.BoolType, + decls.OverloadExamples(`b'hello' < b'world' // true`)), decls.Overload(overloads.LessTimestamp, - argTypes(types.TimestampType, types.TimestampType), types.BoolType), + argTypes(types.TimestampType, types.TimestampType), types.BoolType, + decls.OverloadExamples(`timestamp('2001-01-01T02:03:04Z') < timestamp('2002-02-02T02:03:04Z') // true`)), decls.Overload(overloads.LessDuration, - argTypes(types.DurationType, types.DurationType), types.BoolType), + argTypes(types.DurationType, types.DurationType), types.BoolType, + decls.OverloadExamples(`duration('1ms') < duration('1s') // true`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { cmp := lhs.(traits.Comparer).Compare(rhs) if cmp == types.IntNegOne { @@ -221,34 +321,51 @@ func init() { }, traits.ComparerType)), function(operators.LessEquals, + decls.FunctionDocs( + `compare two values and return true if the first value is`, + `less than or equal to the second`), decls.Overload(overloads.LessEqualsBool, - argTypes(types.BoolType, types.BoolType), types.BoolType), + argTypes(types.BoolType, types.BoolType), types.BoolType, + decls.OverloadExamples(`false <= true // true`)), decls.Overload(overloads.LessEqualsInt64, - argTypes(types.IntType, types.IntType), types.BoolType), + argTypes(types.IntType, types.IntType), types.BoolType, + decls.OverloadExamples(`-2 <= 3 // true`)), decls.Overload(overloads.LessEqualsInt64Double, - argTypes(types.IntType, types.DoubleType), types.BoolType), + argTypes(types.IntType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`1 <= 1.1 // true`)), decls.Overload(overloads.LessEqualsInt64Uint64, - argTypes(types.IntType, types.UintType), types.BoolType), + argTypes(types.IntType, types.UintType), types.BoolType, + decls.OverloadExamples(`1 <= 2u // true`, `-1 <= 0u // true`)), 
decls.Overload(overloads.LessEqualsUint64, - argTypes(types.UintType, types.UintType), types.BoolType), + argTypes(types.UintType, types.UintType), types.BoolType, + decls.OverloadExamples(`1u <= 2u // true`)), decls.Overload(overloads.LessEqualsUint64Double, - argTypes(types.UintType, types.DoubleType), types.BoolType), + argTypes(types.UintType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`1u <= 1.0 // true`, `1u <= 1.1 // true`)), decls.Overload(overloads.LessEqualsUint64Int64, - argTypes(types.UintType, types.IntType), types.BoolType), + argTypes(types.UintType, types.IntType), types.BoolType, + decls.OverloadExamples(`1u <= 23 // true`)), decls.Overload(overloads.LessEqualsDouble, - argTypes(types.DoubleType, types.DoubleType), types.BoolType), + argTypes(types.DoubleType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`2.0 <= 2.4 // true`)), decls.Overload(overloads.LessEqualsDoubleInt64, - argTypes(types.DoubleType, types.IntType), types.BoolType), + argTypes(types.DoubleType, types.IntType), types.BoolType, + decls.OverloadExamples(`2.1 <= 3 // true`)), decls.Overload(overloads.LessEqualsDoubleUint64, - argTypes(types.DoubleType, types.UintType), types.BoolType), + argTypes(types.DoubleType, types.UintType), types.BoolType, + decls.OverloadExamples(`2.0 <= 2u // true`, `-1.0 <= 1u // true`)), decls.Overload(overloads.LessEqualsString, - argTypes(types.StringType, types.StringType), types.BoolType), + argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples(`'a' <= 'b' // true`, `'a' <= 'a' // true`, `'cat' <= 'cab' // false`)), decls.Overload(overloads.LessEqualsBytes, - argTypes(types.BytesType, types.BytesType), types.BoolType), + argTypes(types.BytesType, types.BytesType), types.BoolType, + decls.OverloadExamples(`b'hello' <= b'world' // true`)), decls.Overload(overloads.LessEqualsTimestamp, - argTypes(types.TimestampType, types.TimestampType), types.BoolType), + argTypes(types.TimestampType, 
types.TimestampType), types.BoolType, + decls.OverloadExamples(`timestamp('2001-01-01T02:03:04Z') <= timestamp('2002-02-02T02:03:04Z') // true`)), decls.Overload(overloads.LessEqualsDuration, - argTypes(types.DurationType, types.DurationType), types.BoolType), + argTypes(types.DurationType, types.DurationType), types.BoolType, + decls.OverloadExamples(`duration('1ms') <= duration('1s') // true`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { cmp := lhs.(traits.Comparer).Compare(rhs) if cmp == types.IntNegOne || cmp == types.IntZero { @@ -261,34 +378,51 @@ func init() { }, traits.ComparerType)), function(operators.Greater, + decls.FunctionDocs( + `compare two values and return true if the first value is`, + `greater than the second`), decls.Overload(overloads.GreaterBool, - argTypes(types.BoolType, types.BoolType), types.BoolType), + argTypes(types.BoolType, types.BoolType), types.BoolType, + decls.OverloadExamples(`true > false // true`)), decls.Overload(overloads.GreaterInt64, - argTypes(types.IntType, types.IntType), types.BoolType), + argTypes(types.IntType, types.IntType), types.BoolType, + decls.OverloadExamples(`3 > -2 // true`)), decls.Overload(overloads.GreaterInt64Double, - argTypes(types.IntType, types.DoubleType), types.BoolType), + argTypes(types.IntType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`2 > 1.1 // true`)), decls.Overload(overloads.GreaterInt64Uint64, - argTypes(types.IntType, types.UintType), types.BoolType), + argTypes(types.IntType, types.UintType), types.BoolType, + decls.OverloadExamples(`3 > 2u // true`)), decls.Overload(overloads.GreaterUint64, - argTypes(types.UintType, types.UintType), types.BoolType), + argTypes(types.UintType, types.UintType), types.BoolType, + decls.OverloadExamples(`2u > 1u // true`)), decls.Overload(overloads.GreaterUint64Double, - argTypes(types.UintType, types.DoubleType), types.BoolType), + argTypes(types.UintType, types.DoubleType), types.BoolType, + 
decls.OverloadExamples(`2u > 1.9 // true`)), decls.Overload(overloads.GreaterUint64Int64, - argTypes(types.UintType, types.IntType), types.BoolType), + argTypes(types.UintType, types.IntType), types.BoolType, + decls.OverloadExamples(`23u > 1 // true`, `0u > -1 // true`)), decls.Overload(overloads.GreaterDouble, - argTypes(types.DoubleType, types.DoubleType), types.BoolType), + argTypes(types.DoubleType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`2.4 > 2.0 // true`)), decls.Overload(overloads.GreaterDoubleInt64, - argTypes(types.DoubleType, types.IntType), types.BoolType), + argTypes(types.DoubleType, types.IntType), types.BoolType, + decls.OverloadExamples(`3.1 > 3 // true`, `3.0 > 3 // false`)), decls.Overload(overloads.GreaterDoubleUint64, - argTypes(types.DoubleType, types.UintType), types.BoolType), + argTypes(types.DoubleType, types.UintType), types.BoolType, + decls.OverloadExamples(`2.3 > 2u // true`)), decls.Overload(overloads.GreaterString, - argTypes(types.StringType, types.StringType), types.BoolType), + argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples(`'b' > 'a' // true`)), decls.Overload(overloads.GreaterBytes, - argTypes(types.BytesType, types.BytesType), types.BoolType), + argTypes(types.BytesType, types.BytesType), types.BoolType, + decls.OverloadExamples(`b'world' > b'hello' // true`)), decls.Overload(overloads.GreaterTimestamp, - argTypes(types.TimestampType, types.TimestampType), types.BoolType), + argTypes(types.TimestampType, types.TimestampType), types.BoolType, + decls.OverloadExamples(`timestamp('2002-02-02T02:03:04Z') > timestamp('2001-01-01T02:03:04Z') // true`)), decls.Overload(overloads.GreaterDuration, - argTypes(types.DurationType, types.DurationType), types.BoolType), + argTypes(types.DurationType, types.DurationType), types.BoolType, + decls.OverloadExamples(`duration('1ms') > duration('1us') // true`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { cmp := 
lhs.(traits.Comparer).Compare(rhs) if cmp == types.IntOne { @@ -301,34 +435,51 @@ func init() { }, traits.ComparerType)), function(operators.GreaterEquals, + decls.FunctionDocs( + `compare two values and return true if the first value is`, + `greater than or equal to the second`), decls.Overload(overloads.GreaterEqualsBool, - argTypes(types.BoolType, types.BoolType), types.BoolType), + argTypes(types.BoolType, types.BoolType), types.BoolType, + decls.OverloadExamples(`true >= false // true`)), decls.Overload(overloads.GreaterEqualsInt64, - argTypes(types.IntType, types.IntType), types.BoolType), + argTypes(types.IntType, types.IntType), types.BoolType, + decls.OverloadExamples(`3 >= -2 // true`)), decls.Overload(overloads.GreaterEqualsInt64Double, - argTypes(types.IntType, types.DoubleType), types.BoolType), + argTypes(types.IntType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`2 >= 1.1 // true`, `1 >= 1.0 // true`)), decls.Overload(overloads.GreaterEqualsInt64Uint64, - argTypes(types.IntType, types.UintType), types.BoolType), + argTypes(types.IntType, types.UintType), types.BoolType, + decls.OverloadExamples(`3 >= 2u // true`)), decls.Overload(overloads.GreaterEqualsUint64, - argTypes(types.UintType, types.UintType), types.BoolType), + argTypes(types.UintType, types.UintType), types.BoolType, + decls.OverloadExamples(`2u >= 1u // true`)), decls.Overload(overloads.GreaterEqualsUint64Double, - argTypes(types.UintType, types.DoubleType), types.BoolType), + argTypes(types.UintType, types.DoubleType), types.BoolType, + decls.OverloadExamples(`2u >= 1.9 // true`)), decls.Overload(overloads.GreaterEqualsUint64Int64, - argTypes(types.UintType, types.IntType), types.BoolType), + argTypes(types.UintType, types.IntType), types.BoolType, + decls.OverloadExamples(`23u >= 1 // true`, `1u >= 1 // true`)), decls.Overload(overloads.GreaterEqualsDouble, - argTypes(types.DoubleType, types.DoubleType), types.BoolType), + argTypes(types.DoubleType, types.DoubleType), 
types.BoolType, + decls.OverloadExamples(`2.4 >= 2.0 // true`)), decls.Overload(overloads.GreaterEqualsDoubleInt64, - argTypes(types.DoubleType, types.IntType), types.BoolType), + argTypes(types.DoubleType, types.IntType), types.BoolType, + decls.OverloadExamples(`3.1 >= 3 // true`)), decls.Overload(overloads.GreaterEqualsDoubleUint64, - argTypes(types.DoubleType, types.UintType), types.BoolType), + argTypes(types.DoubleType, types.UintType), types.BoolType, + decls.OverloadExamples(`2.3 >= 2u // true`)), decls.Overload(overloads.GreaterEqualsString, - argTypes(types.StringType, types.StringType), types.BoolType), + argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples(`'b' >= 'a' // true`)), decls.Overload(overloads.GreaterEqualsBytes, - argTypes(types.BytesType, types.BytesType), types.BoolType), + argTypes(types.BytesType, types.BytesType), types.BoolType, + decls.OverloadExamples(`b'world' >= b'hello' // true`)), decls.Overload(overloads.GreaterEqualsTimestamp, - argTypes(types.TimestampType, types.TimestampType), types.BoolType), + argTypes(types.TimestampType, types.TimestampType), types.BoolType, + decls.OverloadExamples(`timestamp('2001-01-01T02:03:04Z') >= timestamp('2001-01-01T02:03:04Z') // true`)), decls.Overload(overloads.GreaterEqualsDuration, - argTypes(types.DurationType, types.DurationType), types.BoolType), + argTypes(types.DurationType, types.DurationType), types.BoolType, + decls.OverloadExamples(`duration('60s') >= duration('1m') // true`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { cmp := lhs.(traits.Comparer).Compare(rhs) if cmp == types.IntOne || cmp == types.IntZero { @@ -342,16 +493,28 @@ func init() { // Indexing function(operators.Index, - decls.Overload(overloads.IndexList, argTypes(listOfA, types.IntType), paramA), - decls.Overload(overloads.IndexMap, argTypes(mapOfAB, paramA), paramB), + decls.FunctionDocs(`select a value from a list by index, or value from a map by key`), + 
decls.Overload(overloads.IndexList, argTypes(listOfA, types.IntType), paramA, + decls.OverloadExamples(`[1, 2, 3][1] // 2`)), + decls.Overload(overloads.IndexMap, argTypes(mapOfAB, paramA), paramB, + decls.OverloadExamples( + `{'key': 'value'}['key'] // 'value'`, + `{'key': 'value'}['missing'] // error`)), decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val { return lhs.(traits.Indexer).Get(rhs) }, traits.IndexerType)), // Collections operators function(operators.In, - decls.Overload(overloads.InList, argTypes(paramA, listOfA), types.BoolType), - decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType), + decls.FunctionDocs(`test whether a value exists in a list, or a key exists in a map`), + decls.Overload(overloads.InList, argTypes(paramA, listOfA), types.BoolType, + decls.OverloadExamples( + `2 in [1, 2, 3] // true`, + `"a" in ["b", "c"] // false`)), + decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType, + decls.OverloadExamples( + `'key1' in {'key1': 'value1', 'key2': 'value2'} // true`, + `3 in {1: "one", 2: "two"} // false`)), decls.SingletonBinaryBinding(inAggregate)), function(operators.OldIn, decls.DisableDeclaration(true), // safe deprecation @@ -364,249 +527,360 @@ func init() { decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType), decls.SingletonBinaryBinding(inAggregate)), function(overloads.Size, - decls.Overload(overloads.SizeBytes, argTypes(types.BytesType), types.IntType), - decls.MemberOverload(overloads.SizeBytesInst, argTypes(types.BytesType), types.IntType), - decls.Overload(overloads.SizeList, argTypes(listOfA), types.IntType), - decls.MemberOverload(overloads.SizeListInst, argTypes(listOfA), types.IntType), - decls.Overload(overloads.SizeMap, argTypes(mapOfAB), types.IntType), - decls.MemberOverload(overloads.SizeMapInst, argTypes(mapOfAB), types.IntType), - decls.Overload(overloads.SizeString, argTypes(types.StringType), types.IntType), - 
decls.MemberOverload(overloads.SizeStringInst, argTypes(types.StringType), types.IntType), + decls.FunctionDocs( + `compute the size of a list or map, the number of characters in a string,`, + `or the number of bytes in a sequence`), + decls.Overload(overloads.SizeBytes, argTypes(types.BytesType), types.IntType, + decls.OverloadExamples(`size(b'123') // 3`)), + decls.MemberOverload(overloads.SizeBytesInst, argTypes(types.BytesType), types.IntType, + decls.OverloadExamples(`b'123'.size() // 3`)), + decls.Overload(overloads.SizeList, argTypes(listOfA), types.IntType, + decls.OverloadExamples(`size([1, 2, 3]) // 3`)), + decls.MemberOverload(overloads.SizeListInst, argTypes(listOfA), types.IntType, + decls.OverloadExamples(`[1, 2, 3].size() // 3`)), + decls.Overload(overloads.SizeMap, argTypes(mapOfAB), types.IntType, + decls.OverloadExamples(`size({'a': 1, 'b': 2}) // 2`)), + decls.MemberOverload(overloads.SizeMapInst, argTypes(mapOfAB), types.IntType, + decls.OverloadExamples(`{'a': 1, 'b': 2}.size() // 2`)), + decls.Overload(overloads.SizeString, argTypes(types.StringType), types.IntType, + decls.OverloadExamples(`size('hello') // 5`)), + decls.MemberOverload(overloads.SizeStringInst, argTypes(types.StringType), types.IntType, + decls.OverloadExamples(`'hello'.size() // 5`)), decls.SingletonUnaryBinding(func(val ref.Val) ref.Val { return val.(traits.Sizer).Size() }, traits.SizerType)), // Type conversions function(overloads.TypeConvertType, - decls.Overload(overloads.TypeConvertType, argTypes(paramA), types.NewTypeTypeWithParam(paramA)), + decls.FunctionDocs(`convert a value to its type identifier`), + decls.Overload(overloads.TypeConvertType, argTypes(paramA), types.NewTypeTypeWithParam(paramA), + decls.OverloadExamples( + `type(1) // int`, + `type('hello') // string`, + `type(int) // type`, + `type(type) // type`)), decls.SingletonUnaryBinding(convertToType(types.TypeType))), // Bool conversions function(overloads.TypeConvertBool, + decls.FunctionDocs(`convert a 
value to a boolean`), decls.Overload(overloads.BoolToBool, argTypes(types.BoolType), types.BoolType, + + decls.OverloadExamples(`bool(true) // true`), decls.UnaryBinding(identity)), decls.Overload(overloads.StringToBool, argTypes(types.StringType), types.BoolType, + + decls.OverloadExamples(`bool('true') // true`, `bool('false') // false`), decls.UnaryBinding(convertToType(types.BoolType)))), // Bytes conversions function(overloads.TypeConvertBytes, + decls.FunctionDocs(`convert a value to bytes`), decls.Overload(overloads.BytesToBytes, argTypes(types.BytesType), types.BytesType, + decls.OverloadExamples(`bytes(b'abc') // b'abc'`), decls.UnaryBinding(identity)), decls.Overload(overloads.StringToBytes, argTypes(types.StringType), types.BytesType, + decls.OverloadExamples(`bytes('hello') // b'hello'`), decls.UnaryBinding(convertToType(types.BytesType)))), // Double conversions function(overloads.TypeConvertDouble, + decls.FunctionDocs(`convert a value to a double`), decls.Overload(overloads.DoubleToDouble, argTypes(types.DoubleType), types.DoubleType, + decls.OverloadExamples(`double(1.23) // 1.23`), decls.UnaryBinding(identity)), decls.Overload(overloads.IntToDouble, argTypes(types.IntType), types.DoubleType, + decls.OverloadExamples(`double(123) // 123.0`), decls.UnaryBinding(convertToType(types.DoubleType))), decls.Overload(overloads.StringToDouble, argTypes(types.StringType), types.DoubleType, + decls.OverloadExamples(`double('1.23') // 1.23`), decls.UnaryBinding(convertToType(types.DoubleType))), decls.Overload(overloads.UintToDouble, argTypes(types.UintType), types.DoubleType, + decls.OverloadExamples(`double(123u) // 123.0`), decls.UnaryBinding(convertToType(types.DoubleType)))), // Duration conversions function(overloads.TypeConvertDuration, + decls.FunctionDocs(`convert a value to a google.protobuf.Duration`), decls.Overload(overloads.DurationToDuration, argTypes(types.DurationType), types.DurationType, + decls.OverloadExamples(`duration(duration('1s')) // 
duration('1s')`), decls.UnaryBinding(identity)), decls.Overload(overloads.IntToDuration, argTypes(types.IntType), types.DurationType, decls.UnaryBinding(convertToType(types.DurationType))), decls.Overload(overloads.StringToDuration, argTypes(types.StringType), types.DurationType, + decls.OverloadExamples(`duration('1h2m3s') // duration('3723s')`), decls.UnaryBinding(convertToType(types.DurationType)))), // Dyn conversions function(overloads.TypeConvertDyn, - decls.Overload(overloads.ToDyn, argTypes(paramA), types.DynType), + decls.FunctionDocs(`indicate that the type is dynamic for type-checking purposes`), + decls.Overload(overloads.ToDyn, argTypes(paramA), types.DynType, + decls.OverloadExamples(`dyn(1) // 1`)), decls.SingletonUnaryBinding(identity)), // Int conversions function(overloads.TypeConvertInt, + decls.FunctionDocs(`convert a value to an int`), decls.Overload(overloads.IntToInt, argTypes(types.IntType), types.IntType, + decls.OverloadExamples(`int(123) // 123`), decls.UnaryBinding(identity)), decls.Overload(overloads.DoubleToInt, argTypes(types.DoubleType), types.IntType, + decls.OverloadExamples(`int(123.45) // 123`), decls.UnaryBinding(convertToType(types.IntType))), decls.Overload(overloads.DurationToInt, argTypes(types.DurationType), types.IntType, - decls.UnaryBinding(convertToType(types.IntType))), + decls.OverloadExamples(`int(duration('1s')) // 1000000000`), + decls.UnaryBinding(convertToType(types.IntType))), // Duration to nanoseconds decls.Overload(overloads.StringToInt, argTypes(types.StringType), types.IntType, + decls.OverloadExamples(`int('123') // 123`, `int('-456') // -456`), decls.UnaryBinding(convertToType(types.IntType))), decls.Overload(overloads.TimestampToInt, argTypes(types.TimestampType), types.IntType, - decls.UnaryBinding(convertToType(types.IntType))), + decls.OverloadExamples(`int(timestamp('1970-01-01T00:00:01Z')) // 1`), + decls.UnaryBinding(convertToType(types.IntType))), // Timestamp to epoch seconds 
decls.Overload(overloads.UintToInt, argTypes(types.UintType), types.IntType, - decls.UnaryBinding(convertToType(types.IntType))), - ), + decls.OverloadExamples(`int(123u) // 123`), + decls.UnaryBinding(convertToType(types.IntType)))), // String conversions function(overloads.TypeConvertString, + decls.FunctionDocs(`convert a value to a string`), decls.Overload(overloads.StringToString, argTypes(types.StringType), types.StringType, + decls.OverloadExamples(`string('hello') // 'hello'`), decls.UnaryBinding(identity)), decls.Overload(overloads.BoolToString, argTypes(types.BoolType), types.StringType, + decls.OverloadExamples(`string(true) // 'true'`), decls.UnaryBinding(convertToType(types.StringType))), decls.Overload(overloads.BytesToString, argTypes(types.BytesType), types.StringType, + decls.OverloadExamples(`string(b'hello') // 'hello'`), decls.UnaryBinding(convertToType(types.StringType))), decls.Overload(overloads.DoubleToString, argTypes(types.DoubleType), types.StringType, - decls.UnaryBinding(convertToType(types.StringType))), + decls.UnaryBinding(convertToType(types.StringType)), + decls.OverloadExamples(`string(-1.23e4) // '-12300'`)), decls.Overload(overloads.DurationToString, argTypes(types.DurationType), types.StringType, + decls.OverloadExamples(`string(duration('1h30m')) // '5400s'`), decls.UnaryBinding(convertToType(types.StringType))), decls.Overload(overloads.IntToString, argTypes(types.IntType), types.StringType, + decls.OverloadExamples(`string(-123) // '-123'`), decls.UnaryBinding(convertToType(types.StringType))), decls.Overload(overloads.TimestampToString, argTypes(types.TimestampType), types.StringType, + decls.OverloadExamples(`string(timestamp('1970-01-01T00:00:00Z')) // '1970-01-01T00:00:00Z'`), decls.UnaryBinding(convertToType(types.StringType))), decls.Overload(overloads.UintToString, argTypes(types.UintType), types.StringType, + decls.OverloadExamples(`string(123u) // '123'`), decls.UnaryBinding(convertToType(types.StringType)))), // 
Timestamp conversions function(overloads.TypeConvertTimestamp, + decls.FunctionDocs(`convert a value to a google.protobuf.Timestamp`), decls.Overload(overloads.TimestampToTimestamp, argTypes(types.TimestampType), types.TimestampType, + decls.OverloadExamples(`timestamp(timestamp('2023-01-01T00:00:00Z')) // timestamp('2023-01-01T00:00:00Z')`), decls.UnaryBinding(identity)), decls.Overload(overloads.IntToTimestamp, argTypes(types.IntType), types.TimestampType, + decls.OverloadExamples(`timestamp(1) // timestamp('1970-01-01T00:00:01Z')`), // Epoch seconds to Timestamp decls.UnaryBinding(convertToType(types.TimestampType))), decls.Overload(overloads.StringToTimestamp, argTypes(types.StringType), types.TimestampType, + decls.OverloadExamples(`timestamp('2025-01-01T12:34:56Z') // timestamp('2025-01-01T12:34:56Z')`), decls.UnaryBinding(convertToType(types.TimestampType)))), // Uint conversions function(overloads.TypeConvertUint, + decls.FunctionDocs(`convert a value to a uint`), decls.Overload(overloads.UintToUint, argTypes(types.UintType), types.UintType, + decls.OverloadExamples(`uint(123u) // 123u`), decls.UnaryBinding(identity)), decls.Overload(overloads.DoubleToUint, argTypes(types.DoubleType), types.UintType, + decls.OverloadExamples(`uint(123.45) // 123u`), decls.UnaryBinding(convertToType(types.UintType))), decls.Overload(overloads.IntToUint, argTypes(types.IntType), types.UintType, + decls.OverloadExamples(`uint(123) // 123u`), decls.UnaryBinding(convertToType(types.UintType))), decls.Overload(overloads.StringToUint, argTypes(types.StringType), types.UintType, + decls.OverloadExamples(`uint('123') // 123u`), decls.UnaryBinding(convertToType(types.UintType)))), // String functions function(overloads.Contains, + decls.FunctionDocs(`test whether a string contains a substring`), decls.MemberOverload(overloads.ContainsString, argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples( + `'hello world'.contains('o w') // true`, + `'hello 
world'.contains('goodbye') // false`), decls.BinaryBinding(types.StringContains)), decls.DisableTypeGuards(true)), function(overloads.EndsWith, + decls.FunctionDocs(`test whether a string ends with a substring suffix`), decls.MemberOverload(overloads.EndsWithString, argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples( + `'hello world'.endsWith('world') // true`, + `'hello world'.endsWith('hello') // false`), decls.BinaryBinding(types.StringEndsWith)), decls.DisableTypeGuards(true)), function(overloads.StartsWith, + decls.FunctionDocs(`test whether a string starts with a substring prefix`), decls.MemberOverload(overloads.StartsWithString, argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples( + `'hello world'.startsWith('hello') // true`, + `'hello world'.startsWith('world') // false`), decls.BinaryBinding(types.StringStartsWith)), decls.DisableTypeGuards(true)), function(overloads.Matches, - decls.Overload(overloads.Matches, argTypes(types.StringType, types.StringType), types.BoolType), + decls.FunctionDocs(`test whether a string matches an RE2 regular expression`), + decls.Overload(overloads.Matches, argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples( + `matches('123-456', '^[0-9]+(-[0-9]+)?$') // true`, + `matches('hello', '^h.*o$') // true`)), decls.MemberOverload(overloads.MatchesString, - argTypes(types.StringType, types.StringType), types.BoolType), + argTypes(types.StringType, types.StringType), types.BoolType, + decls.OverloadExamples( + `'123-456'.matches('^[0-9]+(-[0-9]+)?$') // true`, + `'hello'.matches('^h.*o$') // true`)), decls.SingletonBinaryBinding(func(str, pat ref.Val) ref.Val { return str.(traits.Matcher).Match(pat) }, traits.MatcherType)), // Timestamp / duration functions function(overloads.TimeGetFullYear, + decls.FunctionDocs(`get the 0-based full year from a timestamp, UTC unless an IANA timezone is specified.`), 
decls.MemberOverload(overloads.TimestampToYear, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getFullYear() // 2023`), decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetFullYear(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToYearWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-01-01T05:30:00Z').getFullYear('-08:00') // 2022`), decls.BinaryBinding(timestampGetFullYear))), function(overloads.TimeGetMonth, + decls.FunctionDocs(`get the 0-based month from a timestamp, UTC unless an IANA timezone is specified.`), decls.MemberOverload(overloads.TimestampToMonth, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getMonth() // 6`), // July is month 6 decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetMonth(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToMonthWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-01-01T05:30:00Z').getMonth('America/Los_Angeles') // 11`), // December is month 11 decls.BinaryBinding(timestampGetMonth))), function(overloads.TimeGetDayOfYear, + decls.FunctionDocs(`get the 0-based day of the year from a timestamp, UTC unless an IANA timezone is specified.`), decls.MemberOverload(overloads.TimestampToDayOfYear, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-01-02T00:00:00Z').getDayOfYear() // 1`), decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetDayOfYear(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToDayOfYearWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-01-01T05:00:00Z').getDayOfYear('America/Los_Angeles') // 364`), decls.BinaryBinding(timestampGetDayOfYear))), function(overloads.TimeGetDayOfMonth, + 
decls.FunctionDocs(`get the 0-based day of the month from a timestamp, UTC unless an IANA timezone is specified.`), decls.MemberOverload(overloads.TimestampToDayOfMonthZeroBased, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getDayOfMonth() // 13`), decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetDayOfMonthZeroBased(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-01T05:00:00Z').getDayOfMonth('America/Los_Angeles') // 29`), decls.BinaryBinding(timestampGetDayOfMonthZeroBased))), function(overloads.TimeGetDate, + decls.FunctionDocs(`get the 1-based day of the month from a timestamp, UTC unless an IANA timezone is specified.`), decls.MemberOverload(overloads.TimestampToDayOfMonthOneBased, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getDate() // 14`), decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetDayOfMonthOneBased(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToDayOfMonthOneBasedWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-01T05:00:00Z').getDate('America/Los_Angeles') // 30`), decls.BinaryBinding(timestampGetDayOfMonthOneBased))), function(overloads.TimeGetDayOfWeek, + decls.FunctionDocs(`get the 0-based day of the week from a timestamp, UTC unless an IANA timezone is specified.`), decls.MemberOverload(overloads.TimestampToDayOfWeek, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getDayOfWeek() // 5`), // Friday is day 5 decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetDayOfWeek(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToDayOfWeekWithTz, argTypes(types.TimestampType, types.StringType), 
types.IntType, + decls.OverloadExamples(`timestamp('2023-07-16T05:00:00Z').getDayOfWeek('America/Los_Angeles') // 6`), // Saturday is day 6 decls.BinaryBinding(timestampGetDayOfWeek))), function(overloads.TimeGetHours, + decls.FunctionDocs(`get the hours portion from a timestamp, or convert a duration to hours`), decls.MemberOverload(overloads.TimestampToHours, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getHours() // 10`), decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetHours(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToHoursWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getHours('America/Los_Angeles') // 3`), decls.BinaryBinding(timestampGetHours)), decls.MemberOverload(overloads.DurationToHours, argTypes(types.DurationType), types.IntType, + decls.OverloadExamples(`duration('3723s').getHours() // 1`), decls.UnaryBinding(types.DurationGetHours))), function(overloads.TimeGetMinutes, + decls.FunctionDocs(`get the minutes portion from a timestamp, or convert a duration to minutes`), decls.MemberOverload(overloads.TimestampToMinutes, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getMinutes() // 30`), decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetMinutes(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToMinutesWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getMinutes('America/Los_Angeles') // 30`), decls.BinaryBinding(timestampGetMinutes)), decls.MemberOverload(overloads.DurationToMinutes, argTypes(types.DurationType), types.IntType, + decls.OverloadExamples(`duration('3723s').getMinutes() // 62`), decls.UnaryBinding(types.DurationGetMinutes))), function(overloads.TimeGetSeconds, + decls.FunctionDocs(`get the 
seconds portion from a timestamp, or convert a duration to seconds`), decls.MemberOverload(overloads.TimestampToSeconds, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getSeconds() // 45`), decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetSeconds(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToSecondsWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getSeconds('America/Los_Angeles') // 45`), decls.BinaryBinding(timestampGetSeconds)), decls.MemberOverload(overloads.DurationToSeconds, argTypes(types.DurationType), types.IntType, + decls.OverloadExamples(`duration('3723.456s').getSeconds() // 3723`), decls.UnaryBinding(types.DurationGetSeconds))), function(overloads.TimeGetMilliseconds, + decls.FunctionDocs(`get the milliseconds portion from a timestamp`), decls.MemberOverload(overloads.TimestampToMilliseconds, argTypes(types.TimestampType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getMilliseconds() // 123`), decls.UnaryBinding(func(ts ref.Val) ref.Val { return timestampGetMilliseconds(ts, utcTZ) })), decls.MemberOverload(overloads.TimestampToMillisecondsWithTz, argTypes(types.TimestampType, types.StringType), types.IntType, + decls.OverloadExamples(`timestamp('2023-07-14T10:30:45.123Z').getMilliseconds('America/Los_Angeles') // 123`), decls.BinaryBinding(timestampGetMilliseconds)), decls.MemberOverload(overloads.DurationToMilliseconds, argTypes(types.DurationType), types.IntType, diff --git a/parser/macro.go b/parser/macro.go index 6b3b648d3..7661755e2 100644 --- a/parser/macro.go +++ b/parser/macro.go @@ -24,38 +24,74 @@ import ( "github.com/google/cel-go/common/types/ref" ) +// MacroOpt defines a functional option for configuring macro behavior. 
+type MacroOpt func(*macro) *macro + +// MacroDocs configures a list of strings into a multiline description for the macro. +func MacroDocs(docs ...string) MacroOpt { + return func(m *macro) *macro { + m.doc = common.MultilineDescription(docs...) + return m + } +} + +// MacroExamples configures a list of examples, either as a string or common.MultilineString, +// into an example set to be provided with the macro Documentation() call. +func MacroExamples(examples ...string) MacroOpt { + return func(m *macro) *macro { + m.examples = examples + return m + } +} + // NewGlobalMacro creates a Macro for a global function with the specified arg count. -func NewGlobalMacro(function string, argCount int, expander MacroExpander) Macro { - return ¯o{ +func NewGlobalMacro(function string, argCount int, expander MacroExpander, opts ...MacroOpt) Macro { + m := ¯o{ function: function, argCount: argCount, expander: expander} + for _, opt := range opts { + m = opt(m) + } + return m } // NewReceiverMacro creates a Macro for a receiver function matching the specified arg count. -func NewReceiverMacro(function string, argCount int, expander MacroExpander) Macro { - return ¯o{ +func NewReceiverMacro(function string, argCount int, expander MacroExpander, opts ...MacroOpt) Macro { + m := ¯o{ function: function, argCount: argCount, expander: expander, receiverStyle: true} + for _, opt := range opts { + m = opt(m) + } + return m } // NewGlobalVarArgMacro creates a Macro for a global function with a variable arg count. -func NewGlobalVarArgMacro(function string, expander MacroExpander) Macro { - return ¯o{ +func NewGlobalVarArgMacro(function string, expander MacroExpander, opts ...MacroOpt) Macro { + m := ¯o{ function: function, expander: expander, varArgStyle: true} + for _, opt := range opts { + m = opt(m) + } + return m } // NewReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count. 
-func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro { - return ¯o{ +func NewReceiverVarArgMacro(function string, expander MacroExpander, opts ...MacroOpt) Macro { + m := ¯o{ function: function, expander: expander, receiverStyle: true, varArgStyle: true} + for _, opt := range opts { + m = opt(m) + } + return m } // Macro interface for describing the function signature to match and the MacroExpander to apply. @@ -95,6 +131,8 @@ type macro struct { varArgStyle bool argCount int expander MacroExpander + doc string + examples []string } // Function returns the macro's function name (i.e. the function whose syntax it mimics). @@ -125,6 +163,22 @@ func (m *macro) MacroKey() string { return makeMacroKey(m.function, m.argCount, m.receiverStyle) } +func (m *macro) Documentation() *common.Doc { + examples := make([]*common.Doc, len(m.examples)) + for i, ex := range m.examples { + examples[i] = &common.Doc{ + Kind: common.DocExample, + Description: ex, + } + } + return &common.Doc{ + Kind: common.DocMacro, + Name: m.Function(), + Description: common.ParseDescription(m.doc), + Children: examples, + } +} + func makeMacroKey(name string, args int, receiverStyle bool) string { return fmt.Sprintf("%s:%d:%v", name, args, receiverStyle) } @@ -250,37 +304,139 @@ type ExprHelper interface { var ( // HasMacro expands "has(m.f)" which tests the presence of a field, avoiding the need to // specify the field as a string. 
- HasMacro = NewGlobalMacro(operators.Has, 1, MakeHas) + HasMacro = NewGlobalMacro(operators.Has, 1, MakeHas, + MacroDocs( + `check a protocol buffer message for the presence of a field, or check a map`, + `for the presence of a string key.`, + `Only map accesses using the select notation are supported.`), + MacroExamples( + common.MultilineDescription( + `// true if the 'address' field exists in the 'user' message`, + `has(user.address)`), + common.MultilineDescription( + `// test whether the 'key_name' is set on the map which defines it`, + `has({'key_name': 'value'}.key_name) // true`), + common.MultilineDescription( + `// test whether the 'id' field is set to a non-default value on the Expr{} message literal`, + `has(Expr{}.id) // false`), + )) // AllMacro expands "range.all(var, predicate)" into a comprehension which ensures that all // elements in the range satisfy the predicate. - AllMacro = NewReceiverMacro(operators.All, 2, MakeAll) + AllMacro = NewReceiverMacro(operators.All, 2, MakeAll, + MacroDocs(`tests whether all elements in the input list or all keys in a map`, + `satisfy the given predicate. The all macro behaves in a manner consistent with`, + `the Logical AND operator including in how it absorbs errors and short-circuits.`), + MacroExamples( + `[1, 2, 3].all(x, x > 0) // true`, + `[1, 2, 0].all(x, x > 0) // false`, + `['apple', 'banana', 'cherry'].all(fruit, fruit.size() > 3) // true`, + `[3.14, 2.71, 1.61].all(num, num < 3.0) // false`, + `{'a': 1, 'b': 2, 'c': 3}.all(key, key != 'b') // false`, + common.MultilineDescription( + `// an empty list or map as the range will result in a trivially true result`, + `[].all(x, x > 0) // true`), + )) // ExistsMacro expands "range.exists(var, predicate)" into a comprehension which ensures that // some element in the range satisfies the predicate. 
- ExistsMacro = NewReceiverMacro(operators.Exists, 2, MakeExists) + ExistsMacro = NewReceiverMacro(operators.Exists, 2, MakeExists, + MacroDocs(`tests whether any value in the list or any key in the map`, + `satisfies the predicate expression. The exists macro behaves in a manner`, + `consistent with the Logical OR operator including in how it absorbs errors and`, + `short-circuits.`), + MacroExamples( + `[1, 2, 3].exists(i, i % 2 != 0) // true`, + `[0, -1, 5].exists(num, num < 0) // true`, + `{'x': 'foo', 'y': 'bar'}.exists(key, key.startsWith('z')) // false`, + common.MultilineDescription( + `// an empty list or map as the range will result in a trivially false result`, + `[].exists(i, i > 0) // false`), + common.MultilineDescription( + `// test whether a key name equalling 'iss' exists in the map and the`, + `// value contains the substring 'cel.dev'`, + `// tokens = {'sub': 'me', 'iss': 'https://issuer.cel.dev'}`, + `tokens.exists(k, k == 'iss' && tokens[k].contains('cel.dev'))`), + )) // ExistsOneMacro expands "range.exists_one(var, predicate)", which is true if for exactly one // element in range the predicate holds. // Deprecated: Use ExistsOneMacroNew - ExistsOneMacro = NewReceiverMacro(operators.ExistsOne, 2, MakeExistsOne) + ExistsOneMacro = NewReceiverMacro(operators.ExistsOne, 2, MakeExistsOne, + MacroDocs(`tests whether exactly one list element or map key satisfies`, + `the predicate expression. 
This macro does not short-circuit in order to remain`, + `consistent with logical operators being the only operators which can absorb`, + `errors within CEL.`), + MacroExamples( + `[1, 2, 2].exists_one(i, i < 2) // true`, + `{'a': 'hello', 'aa': 'hellohello'}.exists_one(k, k.startsWith('a')) // false`, + `[1, 2, 3, 4].exists_one(num, num % 2 == 0) // false`, + common.MultilineDescription( + `// ensure exactly one key in the map ends in @acme.co`, + `{'wiley@acme.co': 'coyote', 'aa@milne.co': 'bear'}.exists_one(k, k.endsWith('@acme.co')) // true`), + )) // ExistsOneMacroNew expands "range.existsOne(var, predicate)", which is true if for exactly one // element in range the predicate holds. - ExistsOneMacroNew = NewReceiverMacro("existsOne", 2, MakeExistsOne) + ExistsOneMacroNew = NewReceiverMacro("existsOne", 2, MakeExistsOne, + MacroDocs( + `tests whether exactly one list element or map key satisfies the predicate`, + `expression. This macro does not short-circuit in order to remain consistent`, + `with logical operators being the only operators which can absorb errors`, + `within CEL.`), + MacroExamples( + `[1, 2, 2].existsOne(i, i < 2) // true`, + `{'a': 'hello', 'aa': 'hellohello'}.existsOne(k, k.startsWith('a')) // false`, + `[1, 2, 3, 4].existsOne(num, num % 2 == 0) // false`, + common.MultilineDescription( + `// ensure exactly one key in the map ends in @acme.co`, + `{'wiley@acme.co': 'coyote', 'aa@milne.co': 'bear'}.existsOne(k, k.endsWith('@acme.co')) // true`), + )) // MapMacro expands "range.map(var, function)" into a comprehension which applies the function // to each element in the range to produce a new list. 
- MapMacro = NewReceiverMacro(operators.Map, 2, MakeMap) + MapMacro = NewReceiverMacro(operators.Map, 2, MakeMap, + MacroDocs("the three-argument form of map transforms all elements in the input range."), + MacroExamples( + `[1, 2, 3].map(x, x * 2) // [2, 4, 6]`, + `[5, 10, 15].map(x, x / 5) // [1, 2, 3]`, + `['apple', 'banana'].map(fruit, fruit.upperAscii()) // ['APPLE', 'BANANA']`, + common.MultilineDescription( + `// Combine all map key-value pairs into a list`, + `{'hi': 'you', 'howzit': 'bruv'}.map(k,`, + ` k + ":" + {'hi': 'you', 'howzit': 'bruv'}[k]) // ['hi:you', 'howzit:bruv']`), + )) // MapFilterMacro expands "range.map(var, predicate, function)" into a comprehension which // first filters the elements in the range by the predicate, then applies the transform function // to produce a new list. - MapFilterMacro = NewReceiverMacro(operators.Map, 3, MakeMap) + MapFilterMacro = NewReceiverMacro(operators.Map, 3, MakeMap, + MacroDocs(`the four-argument form of the map transforms only elements which satisfy`, + `the predicate which is equivalent to chaining the filter and three-argument`, + `map macros together.`), + MacroExamples( + common.MultilineDescription( + `// multiply only numbers divisible two, by 2`, + `[1, 2, 3, 4].map(num, num % 2 == 0, num * 2) // [4, 8]`), + )) // FilterMacro expands "range.filter(var, predicate)" into a comprehension which filters // elements in the range, producing a new list from the elements that satisfy the predicate. 
- FilterMacro = NewReceiverMacro(operators.Filter, 2, MakeFilter) + FilterMacro = NewReceiverMacro(operators.Filter, 2, MakeFilter, + MacroDocs(`returns a list containing only the elements from the input list`, + `that satisfy the given predicate`), + MacroExamples( + `[1, 2, 3].filter(x, x > 1) // [2, 3]`, + `['cat', 'dog', 'bird', 'fish'].filter(pet, pet.size() == 3) // ['cat', 'dog']`, + `[{'a': 10, 'b': 5, 'c': 20}].map(m, m.filter(key, m[key] > 10)) // [['c']]`, + common.MultilineDescription( + `// filter a list to select only emails with the @cel.dev suffix`, + `['alice@buf.io', 'tristan@cel.dev'].filter(v, v.endsWith('@cel.dev')) // ['tristan@cel.dev']`), + common.MultilineDescription( + `// filter a map into a list, selecting only the values for keys that start with 'http-auth'`, + `{'http-auth-agent': 'secret', 'user-agent': 'mozilla'}.filter(k,`, + ` k.startsWith('http-auth')) // ['secret']`), + )) // AllMacros includes the list of all spec-supported macros. AllMacros = []Macro{ From 94b45f0cae030982e1dadff4173b96718c1128c9 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 11 Apr 2025 12:36:14 -0700 Subject: [PATCH 30/46] Prompt generation for AI-assisted authoring based on a CEL environment (#1160) * AI authoring prompt generation --- cel/BUILD.bazel | 6 + cel/decls.go | 20 ++ cel/env.go | 21 +- cel/macro.go | 30 ++- cel/prompt.go | 155 +++++++++++ cel/prompt_test.go | 73 ++++++ cel/templates/BUILD.bazel | 7 + cel/templates/authoring.tmpl | 56 ++++ cel/testdata/BUILD.bazel | 5 + cel/testdata/basic.prompt.md | 32 +++ cel/testdata/macros.prompt.md | 100 +++++++ cel/testdata/standard_env.prompt.md | 387 ++++++++++++++++++++++++++++ 12 files changed, 880 insertions(+), 12 deletions(-) create mode 100644 cel/prompt.go create mode 100644 cel/prompt_test.go create mode 100644 cel/templates/BUILD.bazel create mode 100644 cel/templates/authoring.tmpl create mode 100644 cel/testdata/basic.prompt.md create mode 100644 cel/testdata/macros.prompt.md create 
mode 100644 cel/testdata/standard_env.prompt.md diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index d89595821..4a0425a8e 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -18,8 +18,10 @@ go_library( "optimizer.go", "options.go", "program.go", + "prompt.go", "validator.go", ], + embedsrcs = ["//cel/templates"], importpath = "github.com/google/cel-go/cel", visibility = ["//visibility:public"], deps = [ @@ -65,6 +67,7 @@ go_test( "io_test.go", "inlining_test.go", "optimizer_test.go", + "prompt_test.go", "validator_test.go", ], data = [ @@ -73,6 +76,9 @@ go_test( embed = [ ":go_default_library", ], + embedsrcs = [ + "//cel/testdata:prompts", + ], deps = [ "//common/operators:go_default_library", "//common/overloads:go_default_library", diff --git a/cel/decls.go b/cel/decls.go index eedc909bb..3c7d891f7 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -148,6 +148,14 @@ func Variable(name string, t *Type) EnvOption { } } +// VariableWithDoc creates an instance of a variable declaration with a variable name, type, and doc string. +func VariableWithDoc(name string, t *Type, doc string) EnvOption { + return func(e *Env) (*Env, error) { + e.variables = append(e.variables, decls.NewVariableWithDoc(name, t, doc)) + return e, nil + } +} + // VariableDecls configures a set of fully defined cel.VariableDecl instances in the environment. func VariableDecls(vars ...*decls.VariableDecl) EnvOption { return func(e *Env) (*Env, error) { @@ -239,6 +247,13 @@ func FunctionDecls(funcs ...*decls.FunctionDecl) EnvOption { // FunctionOpt defines a functional option for configuring a function declaration. type FunctionOpt = decls.FunctionOpt +// FunctionDocs provides a general usage documentation for the function. +// +// Use OverloadExamples to provide example usage instructions for specific overloads. +func FunctionDocs(docs ...string) FunctionOpt { + return decls.FunctionDocs(docs...) 
+} + // SingletonUnaryBinding creates a singleton function definition to be used for all function overloads. // // Note, this approach works well if operand is expected to have a specific trait which it implements, @@ -312,6 +327,11 @@ func MemberOverload(overloadID string, args []*Type, resultType *Type, opts ...O // OverloadOpt is a functional option for configuring a function overload. type OverloadOpt = decls.OverloadOpt +// OverloadExamples configures an example of how to invoke the overload. +func OverloadExamples(docs ...string) OverloadOpt { + return decls.OverloadExamples(docs...) +} + // UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func UnaryBinding(binding functions.UnaryOp) OverloadOpt { diff --git a/cel/env.go b/cel/env.go index 3e2c62876..bb3014464 100644 --- a/cel/env.go +++ b/cel/env.go @@ -553,14 +553,27 @@ func (e *Env) HasFunction(functionName string) bool { return ok } -// Functions returns map of Functions, keyed by function name, that have been configured in the environment. +// Functions returns a shallow copy of the Functions, keyed by function name, that have been configured in the environment. func (e *Env) Functions() map[string]*decls.FunctionDecl { - return e.functions + shallowCopy := make(map[string]*decls.FunctionDecl, len(e.functions)) + for nm, fn := range e.functions { + shallowCopy[nm] = fn + } + return shallowCopy } -// Variables returns the set of variables associated with the environment. +// Variables returns a shallow copy of the variables associated with the environment. func (e *Env) Variables() []*decls.VariableDecl { - return e.variables[:] + shallowCopy := make([]*decls.VariableDecl, len(e.variables)) + copy(shallowCopy, e.variables) + return shallowCopy +} + +// Macros returns a shallow copy of macros associated with the environment. 
+func (e *Env) Macros() []Macro { + shallowCopy := make([]Macro, len(e.macros)) + copy(shallowCopy, e.macros) + return shallowCopy } // HasValidator returns whether a specific ASTValidator has been configured in the environment. diff --git a/cel/macro.go b/cel/macro.go index 4db1fd57a..3d3c5be1b 100644 --- a/cel/macro.go +++ b/cel/macro.go @@ -142,24 +142,38 @@ type MacroExprHelper interface { NewError(exprID int64, message string) *Error } +// MacroOpt defines a functional option for configuring macro behavior. +type MacroOpt = parser.MacroOpt + +// MacroDocs configures a list of strings into a multiline description for the macro. +func MacroDocs(docs ...string) MacroOpt { + return parser.MacroDocs(docs...) +} + +// MacroExamples configures a list of examples, either as a string or common.MultilineString, +// into an example set to be provided with the macro Documentation() call. +func MacroExamples(examples ...string) MacroOpt { + return parser.MacroExamples(examples...) +} + // GlobalMacro creates a Macro for a global function with the specified arg count. -func GlobalMacro(function string, argCount int, factory MacroFactory) Macro { - return parser.NewGlobalMacro(function, argCount, factory) +func GlobalMacro(function string, argCount int, factory MacroFactory, opts ...MacroOpt) Macro { + return parser.NewGlobalMacro(function, argCount, factory, opts...) } // ReceiverMacro creates a Macro for a receiver function matching the specified arg count. -func ReceiverMacro(function string, argCount int, factory MacroFactory) Macro { - return parser.NewReceiverMacro(function, argCount, factory) +func ReceiverMacro(function string, argCount int, factory MacroFactory, opts ...MacroOpt) Macro { + return parser.NewReceiverMacro(function, argCount, factory, opts...) } // GlobalVarArgMacro creates a Macro for a global function with a variable arg count. 
-func GlobalVarArgMacro(function string, factory MacroFactory) Macro { - return parser.NewGlobalVarArgMacro(function, factory) +func GlobalVarArgMacro(function string, factory MacroFactory, opts ...MacroOpt) Macro { + return parser.NewGlobalVarArgMacro(function, factory, opts...) } // ReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count. -func ReceiverVarArgMacro(function string, factory MacroFactory) Macro { - return parser.NewReceiverVarArgMacro(function, factory) +func ReceiverVarArgMacro(function string, factory MacroFactory, opts ...MacroOpt) Macro { + return parser.NewReceiverVarArgMacro(function, factory, opts...) } // NewGlobalMacro creates a Macro for a global function with the specified arg count. diff --git a/cel/prompt.go b/cel/prompt.go new file mode 100644 index 000000000..3f4ed1e1c --- /dev/null +++ b/cel/prompt.go @@ -0,0 +1,155 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + _ "embed" + "sort" + "strings" + "text/template" + + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" +) + +//go:embed templates/authoring.tmpl +var authoringPrompt string + +// AuthoringPrompt creates a prompt template from a CEL environment for the purpose of AI-assisted authoring. 
+func AuthoringPrompt(env *Env) (*Prompt, error) { + funcMap := template.FuncMap{ + "split": func(str string) []string { return strings.Split(str, "\n") }, + } + tmpl := template.New("cel").Funcs(funcMap) + tmpl, err := tmpl.Parse(authoringPrompt) + if err != nil { + return nil, err + } + return &Prompt{ + Persona: defaultPersona, + FormatRules: defaultFormatRules, + GeneralUsage: defaultGeneralUsage, + tmpl: tmpl, + env: env, + }, nil +} + +// Prompt represents the core components of an LLM prompt based on a CEL environment. +// +// All fields of the prompt may be overwritten / modified with support for rendering the +// prompt to a human-readable string. +type Prompt struct { + // Persona indicates something about the kind of user making the request + Persona string + + // FormatRules indicate how the LLM should generate its output + FormatRules string + + // GeneralUsage specifies additional context on how CEL should be used. + GeneralUsage string + + // tmpl is the text template base-configuration for rendering text. + tmpl *template.Template + + // env reference used to collect variables, functions, and macros available to the prompt. + env *Env +} + +type promptInst struct { + *Prompt + + Variables []*common.Doc + Macros []*common.Doc + Functions []*common.Doc + UserPrompt string +} + +// Render renders the user prompt with the associated context from the prompt template +// for use with LLM generators. 
+func (p *Prompt) Render(userPrompt string) string { + var buffer strings.Builder + vars := make([]*common.Doc, len(p.env.Variables())) + for i, v := range p.env.Variables() { + vars[i] = v.Documentation() + } + sort.SliceStable(vars, func(i, j int) bool { + return vars[i].Name < vars[j].Name + }) + macs := make([]*common.Doc, len(p.env.Macros())) + for i, m := range p.env.Macros() { + macs[i] = m.(common.Documentor).Documentation() + } + funcs := make([]*common.Doc, 0, len(p.env.Functions())) + for _, f := range p.env.Functions() { + if _, hidden := hiddenFunctions[f.Name()]; hidden { + continue + } + funcs = append(funcs, f.Documentation()) + } + sort.SliceStable(funcs, func(i, j int) bool { + return funcs[i].Name < funcs[j].Name + }) + inst := &promptInst{ + Prompt: p, + Variables: vars, + Macros: macs, + Functions: funcs, + UserPrompt: userPrompt} + p.tmpl.Execute(&buffer, inst) + return buffer.String() +} + +const ( + defaultPersona = `You are a software engineer with expertise in networking and application security +authoring boolean Common Expression Language (CEL) expressions to ensure firewall, +networking, authentication, and data access is only permitted when all conditions +are satisified.` + + defaultFormatRules = `Output your response as a CEL expression. + +Write the expression with the comment on the first line and the expression on the +subsequent lines. Format the expression using 80-character line limits commonly +found in C++ or Java code.` + + defaultGeneralUsage = `CEL supports Protocol Buffer and JSON types, as well as simple types and aggregate types. 
+ +Simple types include bool, bytes, double, int, string, and uint: + +* double literals must always include a decimal point: 1.0, 3.5, -2.2 +* uint literals must be positive values suffixed with a 'u': 42u +* byte literals are strings prefixed with a 'b': b'1235' +* string literals can use either single quotes or double quotes: 'hello', "world" +* string literals can also be treated as raw strings that do not require any + escaping within the string by using the 'R' prefix: R"""quote: "hi" """ + +Aggregate types include list and map: + +* list literals consist of zero or more values between brackets: "['a', 'b', 'c']" +* map literal consist of colon-separated key-value pairs within braces: "{'key1': 1, 'key2': 2}" +* Only int, uint, string, and bool types are valid map keys. +* Maps containing HTTP headers must always use lower-cased string keys. + +Comments start with two-forward slashes followed by text and a newline.` +) + +var ( + hiddenFunctions = map[string]bool{ + overloads.DeprecatedIn: true, + operators.OldIn: true, + operators.OldNotStrictlyFalse: true, + operators.NotStrictlyFalse: true, + } +) diff --git a/cel/prompt_test.go b/cel/prompt_test.go new file mode 100644 index 000000000..5598cd9b8 --- /dev/null +++ b/cel/prompt_test.go @@ -0,0 +1,73 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package cel + +import ( + _ "embed" + "testing" + + "github.com/google/cel-go/common/env" + "github.com/google/cel-go/test" +) + +//go:embed testdata/basic.prompt.md +var wantBasicPrompt string + +//go:embed testdata/macros.prompt.md +var wantMacrosPrompt string + +//go:embed testdata/standard_env.prompt.md +var wantStandardEnvPrompt string + +func TestPromptTemplate(t *testing.T) { + tests := []struct { + name string + envOpts []EnvOption + out string + }{ + { + name: "basic", + out: wantBasicPrompt, + }, + { + name: "macros", + envOpts: []EnvOption{Macros(StandardMacros...)}, + out: wantMacrosPrompt, + }, + { + name: "standard_env", + envOpts: []EnvOption{StdLib(StdLibSubset(env.NewLibrarySubset().SetDisableMacros(true)))}, + out: wantStandardEnvPrompt, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + env, err := NewCustomEnv(tc.envOpts...) + if err != nil { + t.Fatalf("cel.NewCustomEnv() failed: %v", err) + } + prompt, err := AuthoringPrompt(env) + if err != nil { + t.Fatalf("cel.AuthoringPrompt() failed: %v", err) + } + out := prompt.Render("") + if !test.Compare(out, tc.out) { + t.Errorf("got %s, wanted %s", out, tc.out) + } + }) + } +} diff --git a/cel/templates/BUILD.bazel b/cel/templates/BUILD.bazel new file mode 100644 index 000000000..217024b2f --- /dev/null +++ b/cel/templates/BUILD.bazel @@ -0,0 +1,7 @@ +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "templates", + srcs = glob(["*.tmpl"]), + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/cel/templates/authoring.tmpl b/cel/templates/authoring.tmpl new file mode 100644 index 000000000..d6b3da5c6 --- /dev/null +++ b/cel/templates/authoring.tmpl @@ -0,0 +1,56 @@ +{{define "variable"}}{{.Name}} is a {{.Type}} +{{- end -}} + +{{define "macro" -}} +{{.Name}} macro{{if .Description}} - {{range split .Description}}{{.}} {{end}} +{{end}} +{{range .Children}}{{range split .Description}} {{.}} +{{end}} +{{- end -}} +{{- end 
-}} + +{{define "overload" -}} +{{if .Children}}{{range .Children}}{{range split .Description}} {{.}} +{{end}} +{{- end -}} +{{else}} {{.Signature}} +{{end}} +{{- end -}} + +{{define "function" -}} +{{.Name}}{{if .Description}} - {{range split .Description}}{{.}} {{end}} +{{end}} +{{range .Children}}{{template "overload" .}}{{end}} +{{- end -}} + +{{.Persona}} + +{{.FormatRules}} + +{{if or .Variables .Macros .Functions -}} +Only use the following variables, macros, and functions in expressions. +{{if .Variables}} +Variables: + +{{range .Variables}}* {{template "variable" .}} +{{end -}} + +{{end -}} +{{if .Macros}} +Macros: + +{{range .Macros}}* {{template "macro" .}} +{{end -}} + +{{end -}} +{{if .Functions}} +Functions: + +{{range .Functions}}* {{template "function" .}} +{{end -}} + +{{end -}} +{{- end -}} +{{.GeneralUsage}} + +{{.UserPrompt}} diff --git a/cel/testdata/BUILD.bazel b/cel/testdata/BUILD.bazel index 8fffa25ca..e643fbcae 100644 --- a/cel/testdata/BUILD.bazel +++ b/cel/testdata/BUILD.bazel @@ -16,3 +16,8 @@ genrule( "--descriptor_set_out=$@ $(SRCS)"), tools = ["@com_google_protobuf//:protoc"], ) + +filegroup( + name = "prompts", + srcs = glob(["*.prompt.md"]), +) \ No newline at end of file diff --git a/cel/testdata/basic.prompt.md b/cel/testdata/basic.prompt.md new file mode 100644 index 000000000..5f09cef8c --- /dev/null +++ b/cel/testdata/basic.prompt.md @@ -0,0 +1,32 @@ +You are a software engineer with expertise in networking and application security +authoring boolean Common Expression Language (CEL) expressions to ensure firewall, +networking, authentication, and data access is only permitted when all conditions +are satisified. + +Output your response as a CEL expression. + +Write the expression with the comment on the first line and the expression on the +subsequent lines. Format the expression using 80-character line limits commonly +found in C++ or Java code. 
+ +CEL supports Protocol Buffer and JSON types, as well as simple types and aggregate types. + +Simple types include bool, bytes, double, int, string, and uint: + +* double literals must always include a decimal point: 1.0, 3.5, -2.2 +* uint literals must be positive values suffixed with a 'u': 42u +* byte literals are strings prefixed with a 'b': b'1235' +* string literals can use either single quotes or double quotes: 'hello', "world" +* string literals can also be treated as raw strings that do not require any + escaping within the string by using the 'R' prefix: R"""quote: "hi" """ + +Aggregate types include list and map: + +* list literals consist of zero or more values between brackets: "['a', 'b', 'c']" +* map literal consist of colon-separated key-value pairs within braces: "{'key1': 1, 'key2': 2}" +* Only int, uint, string, and bool types are valid map keys. +* Maps containing HTTP headers must always use lower-cased string keys. + +Comments start with two-forward slashes followed by text and a newline. + + diff --git a/cel/testdata/macros.prompt.md b/cel/testdata/macros.prompt.md new file mode 100644 index 000000000..9e6f0934a --- /dev/null +++ b/cel/testdata/macros.prompt.md @@ -0,0 +1,100 @@ +You are a software engineer with expertise in networking and application security +authoring boolean Common Expression Language (CEL) expressions to ensure firewall, +networking, authentication, and data access is only permitted when all conditions +are satisified. + +Output your response as a CEL expression. + +Write the expression with the comment on the first line and the expression on the +subsequent lines. Format the expression using 80-character line limits commonly +found in C++ or Java code. + +Only use the following variables, macros, and functions in expressions. + +Macros: + +* has macro - check a protocol buffer message for the presence of a field, or check a map for the presence of a string key. 
Only map accesses using the select notation are supported. + + // true if the 'address' field exists in the 'user' message + has(user.address) + // test whether the 'key_name' is set on the map which defines it + has({'key_name': 'value'}.key_name) // true + // test whether the 'id' field is set to a non-default value on the Expr{} message literal + has(Expr{}.id) // false + +* all macro - tests whether all elements in the input list or all keys in a map satisfy the given predicate. The all macro behaves in a manner consistent with the Logical AND operator including in how it absorbs errors and short-circuits. + + [1, 2, 3].all(x, x > 0) // true + [1, 2, 0].all(x, x > 0) // false + ['apple', 'banana', 'cherry'].all(fruit, fruit.size() > 3) // true + [3.14, 2.71, 1.61].all(num, num < 3.0) // false + {'a': 1, 'b': 2, 'c': 3}.all(key, key != 'b') // false + // an empty list or map as the range will result in a trivially true result + [].all(x, x > 0) // true + +* exists macro - tests whether any value in the list or any key in the map satisfies the predicate expression. The exists macro behaves in a manner consistent with the Logical OR operator including in how it absorbs errors and short-circuits. + + [1, 2, 3].exists(i, i % 2 != 0) // true + [0, -1, 5].exists(num, num < 0) // true + {'x': 'foo', 'y': 'bar'}.exists(key, key.startsWith('z')) // false + // an empty list or map as the range will result in a trivially false result + [].exists(i, i > 0) // false + // test whether a key name equalling 'iss' exists in the map and the + // value contains the substring 'cel.dev' + // tokens = {'sub': 'me', 'iss': 'https://issuer.cel.dev'} + tokens.exists(k, k == 'iss' && tokens[k].contains('cel.dev')) + +* exists_one macro - tests whether exactly one list element or map key satisfies the predicate expression. This macro does not short-circuit in order to remain consistent with logical operators being the only operators which can absorb errors within CEL. 
+ + [1, 2, 2].exists_one(i, i < 2) // true + {'a': 'hello', 'aa': 'hellohello'}.exists_one(k, k.startsWith('a')) // false + [1, 2, 3, 4].exists_one(num, num % 2 == 0) // false + // ensure exactly one key in the map ends in @acme.co + {'wiley@acme.co': 'coyote', 'aa@milne.co': 'bear'}.exists_one(k, k.endsWith('@acme.co')) // true + +* map macro - the three-argument form of map transforms all elements in the input range. + + [1, 2, 3].map(x, x * 2) // [2, 4, 6] + [5, 10, 15].map(x, x / 5) // [1, 2, 3] + ['apple', 'banana'].map(fruit, fruit.upperAscii()) // ['APPLE', 'BANANA'] + // Combine all map key-value pairs into a list + {'hi': 'you', 'howzit': 'bruv'}.map(k, + k + ":" + {'hi': 'you', 'howzit': 'bruv'}[k]) // ['hi:you', 'howzit:bruv'] + +* map macro - the four-argument form of the map transforms only elements which satisfy the predicate which is equivalent to chaining the filter and three-argument map macros together. + + // multiply only numbers divisible two, by 2 + [1, 2, 3, 4].map(num, num % 2 == 0, num * 2) // [4, 8] + +* filter macro - returns a list containing only the elements from the input list that satisfy the given predicate + + [1, 2, 3].filter(x, x > 1) // [2, 3] + ['cat', 'dog', 'bird', 'fish'].filter(pet, pet.size() == 3) // ['cat', 'dog'] + [{'a': 10, 'b': 5, 'c': 20}].map(m, m.filter(key, m[key] > 10)) // [['c']] + // filter a list to select only emails with the @cel.dev suffix + ['alice@buf.io', 'tristan@cel.dev'].filter(v, v.endsWith('@cel.dev')) // ['tristan@cel.dev'] + // filter a map into a list, selecting only the values for keys that start with 'http-auth' + {'http-auth-agent': 'secret', 'user-agent': 'mozilla'}.filter(k, + k.startsWith('http-auth')) // ['secret'] + +CEL supports Protocol Buffer and JSON types, as well as simple types and aggregate types. 
+ +Simple types include bool, bytes, double, int, string, and uint: + +* double literals must always include a decimal point: 1.0, 3.5, -2.2 +* uint literals must be positive values suffixed with a 'u': 42u +* byte literals are strings prefixed with a 'b': b'1235' +* string literals can use either single quotes or double quotes: 'hello', "world" +* string literals can also be treated as raw strings that do not require any + escaping within the string by using the 'R' prefix: R"""quote: "hi" """ + +Aggregate types include list and map: + +* list literals consist of zero or more values between brackets: "['a', 'b', 'c']" +* map literal consist of colon-separated key-value pairs within braces: "{'key1': 1, 'key2': 2}" +* Only int, uint, string, and bool types are valid map keys. +* Maps containing HTTP headers must always use lower-cased string keys. + +Comments start with two-forward slashes followed by text and a newline. + + diff --git a/cel/testdata/standard_env.prompt.md b/cel/testdata/standard_env.prompt.md new file mode 100644 index 000000000..b7a3010b0 --- /dev/null +++ b/cel/testdata/standard_env.prompt.md @@ -0,0 +1,387 @@ +You are a software engineer with expertise in networking and application security +authoring boolean Common Expression Language (CEL) expressions to ensure firewall, +networking, authentication, and data access is only permitted when all conditions +are satisified. + +Output your response as a CEL expression. + +Write the expression with the comment on the first line and the expression on the +subsequent lines. Format the expression using 80-character line limits commonly +found in C++ or Java code. + +Only use the following variables, macros, and functions in expressions. 
+ +Variables: + +* bool is a type +* bytes is a type +* double is a type +* google.protobuf.Duration is a type +* google.protobuf.Timestamp is a type +* int is a type +* list is a type +* map is a type +* null_type is a type +* string is a type +* type is a type +* uint is a type + +Functions: + +* !_ - logically negate a boolean value. + + !true // false + !false // true + !error // error + +* -_ - negate a numeric value + + -(3.14) // -3.14 + -(5) // -5 + +* @in - test whether a value exists in a list, or a key exists in a map + + 2 in [1, 2, 3] // true + "a" in ["b", "c"] // false + 'key1' in {'key1': 'value1', 'key2': 'value2'} // true + 3 in {1: "one", 2: "two"} // false + +* _!=_ - compare two values of the same type for inequality + + 1 != 2 // true + "a" != "a" // false + 3.0 != 3.1 // true + +* _%_ - compute the modulus of one integer into another + + 3 % 2 // 1 + 6u % 3u // 0u + +* _&&_ - logically AND two boolean values. Errors and unknown values are valid inputs and will not halt evaluation. + + true && true // true + true && false // false + error && true // error + error && false // false + +* _*_ - multiply two numbers + + 3.5 * 40.0 // 140.0 + -2 * 6 // -12 + 13u * 3u // 39u + +* _+_ - adds two numeric values or concatenates two strings, bytes, or lists. + + b'hi' + bytes('ya') // b'hiya' + 3.14 + 1.59 // 4.73 + duration('1m') + duration('1s') // duration('1m1s') + duration('24h') + timestamp('2023-01-01T00:00:00Z') // timestamp('2023-01-02T00:00:00Z') + timestamp('2023-01-01T00:00:00Z') + duration('24h1m2s') // timestamp('2023-01-02T00:01:02Z') + 1 + 2 // 3 + [1] + [2, 3] // [1, 2, 3] + "Hello, " + "world!" // "Hello, world!" 
+ 22u + 33u // 55u + +* _-_ - subtract two numbers, or two time-related values + + 10.5 - 2.0 // 8.5 + duration('1m') - duration('1s') // duration('59s') + 5 - 3 // 2 + timestamp('2023-01-10T12:00:00Z') + - duration('12h') // timestamp('2023-01-10T00:00:00Z') + timestamp('2023-01-10T12:00:00Z') + - timestamp('2023-01-10T00:00:00Z') // duration('12h') + // the subtraction result must be positive, otherwise an overflow + // error is generated. + 42u - 3u // 39u + +* _/_ - divide two numbers + + 7.0 / 2.0 // 3.5 + 10 / 2 // 5 + 42u / 2u // 21u + +* _<=_ - compare two values and return true if the first value is less than or equal to the second + + false <= true // true + -2 <= 3 // true + 1 <= 1.1 // true + 1 <= 2u // true + -1 <= 0u // true + 1u <= 2u // true + 1u <= 1.0 // true + 1u <= 1.1 // true + 1u <= 23 // true + 2.0 <= 2.4 // true + 2.1 <= 3 // true + 2.0 <= 2u // true + -1.0 <= 1u // true + 'a' <= 'b' // true + 'a' <= 'a' // true + 'cat' <= 'cab' // false + b'hello' <= b'world' // true + timestamp('2001-01-01T02:03:04Z') <= timestamp('2002-02-02T02:03:04Z') // true + duration('1ms') <= duration('1s') // true + +* _<_ - compare two values and return true if the first value is less than the second + + false < true // true + -2 < 3 // true + 1 < 0 // false + 1 < 1.1 // true + 1 < 2u // true + 1u < 2u // true + 1u < 0.9 // false + 1u < 23 // true + 1u < -1 // false + 2.0 < 2.4 // true + 2.1 < 3 // true + 2.3 < 2u // false + -1.0 < 1u // true + 'a' < 'b' // true + 'cat' < 'cab' // false + b'hello' < b'world' // true + timestamp('2001-01-01T02:03:04Z') < timestamp('2002-02-02T02:03:04Z') // true + duration('1ms') < duration('1s') // true + +* _==_ - compare two values of the same type for equality + + 1 == 1 // true + 'hello' == 'world' // false + bytes('hello') == b'hello' // true + duration('1h') == duration('60m') // true + dyn(3.0) == 3 // true + +* _>=_ - compare two values and return true if the first value is greater than or equal to the second + + true >= 
false // true + 3 >= -2 // true + 2 >= 1.1 // true + 1 >= 1.0 // true + 3 >= 2u // true + 2u >= 1u // true + 2u >= 1.9 // true + 23u >= 1 // true + 1u >= 1 // true + 2.4 >= 2.0 // true + 3.1 >= 3 // true + 2.3 >= 2u // true + 'b' >= 'a' // true + b'world' >= b'hello' // true + timestamp('2001-01-01T02:03:04Z') >= timestamp('2001-01-01T02:03:04Z') // true + duration('60s') >= duration('1m') // true + +* _>_ - compare two values and return true if the first value is greater than the second + + true > false // true + 3 > -2 // true + 2 > 1.1 // true + 3 > 2u // true + 2u > 1u // true + 2u > 1.9 // true + 23u > 1 // true + 0u > -1 // true + 2.4 > 2.0 // true + 3.1 > 3 // true + 3.0 > 3 // false + 2.3 > 2u // true + 'b' > 'a' // true + b'world' > b'hello' // true + timestamp('2002-02-02T02:03:04Z') > timestamp('2001-01-01T02:03:04Z') // true + duration('1ms') > duration('1us') // true + +* _?_:_ - The ternary operator tests a boolean predicate and returns the left-hand side (truthy) expression if true, or the right-hand side (falsy) expression if false + + 'hello'.contains('lo') ? 'hi' : 'bye' // 'hi' + 32 % 3 == 0 ? 'divisible' : 'not divisible' // 'not divisible' + +* _[_] - select a value from a list by index, or value from a map by key + + [1, 2, 3][1] // 2 + {'key': 'value'}['key'] // 'value' + {'key': 'value'}['missing'] // error + +* _||_ - logically OR two boolean values. Errors and unknown values are valid inputs and will not halt evaluation. 
+ + true || false // true + false || false // false + error || true // true + error || error // true + +* bool - convert a value to a boolean + + bool(true) // true + bool('true') // true + bool('false') // false + +* bytes - convert a value to bytes + + bytes(b'abc') // b'abc' + bytes('hello') // b'hello' + +* contains - test whether a string contains a substring + + 'hello world'.contains('o w') // true + 'hello world'.contains('goodbye') // false + +* double - convert a value to a double + + double(1.23) // 1.23 + double(123) // 123.0 + double('1.23') // 1.23 + double(123u) // 123.0 + +* duration - convert a value to a google.protobuf.Duration + + duration(duration('1s')) // duration('1s') + duration(int) -> google.protobuf.Duration + duration('1h2m3s') // duration('3723s') + +* dyn - indicate that the type is dynamic for type-checking purposes + + dyn(1) // 1 + +* endsWith - test whether a string ends with a substring suffix + + 'hello world'.endsWith('world') // true + 'hello world'.endsWith('hello') // false + +* getDate - get the 1-based day of the month from a timestamp, UTC unless an IANA timezone is specified. + + timestamp('2023-07-14T10:30:45.123Z').getDate() // 14 + timestamp('2023-07-01T05:00:00Z').getDate('America/Los_Angeles') // 30 + +* getDayOfMonth - get the 0-based day of the month from a timestamp, UTC unless an IANA timezone is specified. + + timestamp('2023-07-14T10:30:45.123Z').getDayOfMonth() // 13 + timestamp('2023-07-01T05:00:00Z').getDayOfMonth('America/Los_Angeles') // 29 + +* getDayOfWeek - get the 0-based day of the week from a timestamp, UTC unless an IANA timezone is specified. + + timestamp('2023-07-14T10:30:45.123Z').getDayOfWeek() // 5 + timestamp('2023-07-16T05:00:00Z').getDayOfWeek('America/Los_Angeles') // 6 + +* getDayOfYear - get the 0-based day of the year from a timestamp, UTC unless an IANA timezone is specified. 
+ + timestamp('2023-01-02T00:00:00Z').getDayOfYear() // 1 + timestamp('2023-01-01T05:00:00Z').getDayOfYear('America/Los_Angeles') // 364 + +* getFullYear - get the 0-based full year from a timestamp, UTC unless an IANA timezone is specified. + + timestamp('2023-07-14T10:30:45.123Z').getFullYear() // 2023 + timestamp('2023-01-01T05:30:00Z').getFullYear('-08:00') // 2022 + +* getHours - get the hours portion from a timestamp, or convert a duration to hours + + timestamp('2023-07-14T10:30:45.123Z').getHours() // 10 + timestamp('2023-07-14T10:30:45.123Z').getHours('America/Los_Angeles') // 2 + duration('3723s').getHours() // 1 + +* getMilliseconds - get the milliseconds portion from a timestamp + + timestamp('2023-07-14T10:30:45.123Z').getMilliseconds() // 123 + timestamp('2023-07-14T10:30:45.123Z').getMilliseconds('America/Los_Angeles') // 123 + google.protobuf.Duration.getMilliseconds() -> int + +* getMinutes - get the minutes portion from a timestamp, or convert a duration to minutes + + timestamp('2023-07-14T10:30:45.123Z').getMinutes() // 30 + timestamp('2023-07-14T10:30:45.123Z').getMinutes('America/Los_Angeles') // 30 + duration('3723s').getMinutes() // 62 + +* getMonth - get the 0-based month from a timestamp, UTC unless an IANA timezone is specified. 
+ + timestamp('2023-07-14T10:30:45.123Z').getMonth() // 6 + timestamp('2023-01-01T05:30:00Z').getMonth('America/Los_Angeles') // 11 + +* getSeconds - get the seconds portion from a timestamp, or convert a duration to seconds + + timestamp('2023-07-14T10:30:45.123Z').getSeconds() // 45 + timestamp('2023-07-14T10:30:45.123Z').getSeconds('America/Los_Angeles') // 45 + duration('3723.456s').getSeconds() // 3723 + +* int - convert a value to an int + + int(123) // 123 + int(123.45) // 123 + int(duration('1s')) // 1000000000 + int('123') // 123 + int('-456') // -456 + int(timestamp('1970-01-01T00:00:01Z')) // 1 + int(123u) // 123 + +* matches - test whether a string matches an RE2 regular expression + + matches('123-456', '^[0-9]+(-[0-9]+)?$') // true + matches('hello', '^h.*o$') // true + '123-456'.matches('^[0-9]+(-[0-9]+)?$') // true + 'hello'.matches('^h.*o$') // true + +* size - compute the size of a list or map, the number of characters in a string, or the number of bytes in a sequence + + size(b'123') // 3 + b'123'.size() // 3 + size([1, 2, 3]) // 3 + [1, 2, 3].size() // 3 + size({'a': 1, 'b': 2}) // 2 + {'a': 1, 'b': 2}.size() // 2 + size('hello') // 5 + 'hello'.size() // 5 + +* startsWith - test whether a string starts with a substring prefix + + 'hello world'.startsWith('hello') // true + 'hello world'.startsWith('world') // false + +* string - convert a value to a string + + string('hello') // 'hello' + string(true) // 'true' + string(b'hello') // 'hello' + string(-1.23e4) // '-12300' + string(duration('1h30m')) // '5400s' + string(-123) // '-123' + string(timestamp('1970-01-01T00:00:00Z')) // '1970-01-01T00:00:00Z' + string(123u) // '123' + +* timestamp - convert a value to a google.protobuf.Timestamp + + timestamp(timestamp('2023-01-01T00:00:00Z')) // timestamp('2023-01-01T00:00:00Z') + timestamp(1) // timestamp('1970-01-01T00:00:01Z') + timestamp('2025-01-01T12:34:56Z') // timestamp('2025-01-01T12:34:56Z') + +* type - convert a value to its type identifier 
+ + type(1) // int + type('hello') // string + type(int) // type + type(type) // type + +* uint - convert a value to a uint + + uint(123u) // 123u + uint(123.45) // 123u + uint(123) // 123u + uint('123') // 123u + +CEL supports Protocol Buffer and JSON types, as well as simple types and aggregate types. + +Simple types include bool, bytes, double, int, string, and uint: + +* double literals must always include a decimal point: 1.0, 3.5, -2.2 +* uint literals must be positive values suffixed with a 'u': 42u +* byte literals are strings prefixed with a 'b': b'1235' +* string literals can use either single quotes or double quotes: 'hello', "world" +* string literals can also be treated as raw strings that do not require any + escaping within the string by using the 'R' prefix: R"""quote: "hi" """ + +Aggregate types include list and map: + +* list literals consist of zero or more values between brackets: "['a', 'b', 'c']" +* map literal consist of colon-separated key-value pairs within braces: "{'key1': 1, 'key2': 2}" +* Only int, uint, string, and bool types are valid map keys. +* Maps containing HTTP headers must always use lower-cased string keys. + +Comments start with two-forward slashes followed by text and a newline. 
+ + From 6b7ecea793fc4dac4d9835907848044e169fdcde Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 11 Apr 2025 12:36:26 -0700 Subject: [PATCH 31/46] Remove non-functional optional test in field selection (#1161) --- cel/cel_test.go | 16 ++++++++++++++++ interpreter/interpretable.go | 3 --- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 3ff459756..97e4f9f0f 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -2669,6 +2669,22 @@ func TestOptionalValuesEval(t *testing.T) { in map[string]any out any }{ + { + expr: `has({'foo': optional.none()}.foo)`, + out: types.True, + }, + { + expr: `has({'foo': optional.none()}.foo.value)`, + out: types.False, + }, + { + expr: `has({?'foo': optional.none()}.foo)`, + out: types.False, + }, + { + expr: `has({?'foo': optional.none()}.foo.value)`, + out: "no such key: foo", + }, { expr: `{}.?invalid`, out: types.OptionalNone, diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 1990ce017..04bbf3ffe 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -194,9 +194,6 @@ func (q *testOnlyQualifier) Qualify(vars Activation, obj any) (any, error) { if unk, isUnk := out.(types.Unknown); isUnk { return unk, nil } - if opt, isOpt := out.(types.Optional); isOpt { - return opt.HasValue(), nil - } return present, nil } From 0f9133d701cef8ae0a0f43fa4be56428b5b3c1a5 Mon Sep 17 00:00:00 2001 From: Andreas Rohner Date: Mon, 14 Apr 2025 19:46:07 +0200 Subject: [PATCH 32/46] Add LateFunctionBinding declaration and fix constant folding (#1117) Adds the declaration for LateFunctionBindings, which can be used to indicate to the Runtime and the Optimizers that the function will be bound at runtime through the Activation. This lets the constant folding optimizer know that the function potentially has side effects and cannot be folded. 
Without this the optimization will fail with an error for late bound functions where all arguments are constants. The implementation for late bound functions will be added in a subsequent commit. --- cel/decls.go | 6 +++ cel/folding.go | 24 +++++++++-- cel/folding_test.go | 85 +++++++++++++++++++++++++++++++++++++ common/decls/decls.go | 53 +++++++++++++++++++++++ common/decls/decls_test.go | 86 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 251 insertions(+), 3 deletions(-) diff --git a/cel/decls.go b/cel/decls.go index 3c7d891f7..ab5a0b8c3 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -350,6 +350,12 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt { return decls.FunctionBinding(binding) } +// LateFunctionBinding indicates that the function has a binding which is not known at compile time. +// This is useful for functions which have side-effects or are not deterministically computable. +func LateFunctionBinding() OverloadOpt { + return decls.LateFunctionBinding() +} + // OverloadIsNonStrict enables the function to be called with error and unknown argument values. // // Note: do not use this option unless absoluately necessary as it should be an uncommon feature. diff --git a/cel/folding.go b/cel/folding.go index d7060896d..0c7ecc616 100644 --- a/cel/folding.go +++ b/cel/folding.go @@ -68,7 +68,8 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) // Walk the list of foldable expression and continue to fold until there are no more folds left. // All of the fold candidates returned by the constantExprMatcher should succeed unless there's // a logic bug with the selection of expressions. 
- foldableExprs := ast.MatchDescendants(root, constantExprMatcher) + constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return constantExprMatcher(ctx, a, e) } + foldableExprs := ast.MatchDescendants(root, constantExprMatcherCapture) foldCount := 0 for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations { for _, fold := range foldableExprs { @@ -77,6 +78,10 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) { continue } + // Late-bound function calls cannot be folded. + if fold.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, fold) { + continue + } // Otherwise, assume all context is needed to evaluate the expression. err := tryFold(ctx, a, fold) if err != nil { @@ -85,7 +90,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) } } foldCount++ - foldableExprs = ast.MatchDescendants(root, constantExprMatcher) + foldableExprs = ast.MatchDescendants(root, constantExprMatcherCapture) } // Once all of the constants have been folded, try to run through the remaining comprehensions // one last time. In this case, there's no guarantee they'll run, so we only update the @@ -139,6 +144,15 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { return nil } +func isLateBoundFunctionCall(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) bool { + call := expr.AsCall() + function := ctx.Functions()[call.FunctionName()] + if function == nil { + return false + } + return function.HasLateBinding() +} + // maybePruneBranches inspects the non-strict call expression to determine whether // a branch can be removed. 
Evaluation will naturally prune logical and / or calls, // but conditional will not be pruned cleanly, so this is one small area where the @@ -455,7 +469,7 @@ func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) { // Only comprehensions which are not nested are included as possible constant folds, and only // if all variables referenced in the comprehension stack exist are only iteration or // accumulation variables. -func constantExprMatcher(e ast.NavigableExpr) bool { +func constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool { switch e.Kind() { case ast.CallKind: return constantCallMatcher(e) @@ -477,6 +491,10 @@ func constantExprMatcher(e ast.NavigableExpr) bool { if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { constantExprs = false } + // Late-bound function calls cannot be folded. + if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) { + constantExprs = false + } }) ast.PreOrderVisit(e, visitor) return constantExprs diff --git a/cel/folding_test.go b/cel/folding_test.go index 3f24f50ee..c85a36073 100644 --- a/cel/folding_test.go +++ b/cel/folding_test.go @@ -17,12 +17,14 @@ package cel import ( "reflect" "sort" + "strings" "testing" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types/ref" proto3pb "github.com/google/cel-go/test/proto3pb" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -313,6 +315,89 @@ func TestConstantFoldingOptimizer(t *testing.T) { } } +func TestConstantFoldingCallsWithSideEffects(t *testing.T) { + tests := []struct { + expr string + folded string + error string + }{ + { + expr: `noSideEffect(3)`, + folded: `3`, + }, + { + expr: `withSideEffect(3)`, + folded: `withSideEffect(3)`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`, + folded: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) 
== 1)`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && noSideEffect(i.b) == 2)`, + folded: `true`, + }, + { + expr: `noImpl(3)`, + error: `constant-folding evaluation failed: no such overload: noImpl`, + }, + } + e, err := NewEnv( + OptionalTypes(), + EnableMacroCallTracking(), + Function("noSideEffect", + Overload("noSideEffect_int_int", + []*Type{IntType}, + IntType, FunctionBinding(func(args ...ref.Val) ref.Val { + return args[0] + }))), + Function("withSideEffect", + Overload("withSideEffect_int_int", + []*Type{IntType}, + IntType, LateFunctionBinding())), + Function("noImpl", + Overload("noImpl_int_int", + []*Type{IntType}, + IntType)), + ) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + folder, err := NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) + optimized, iss := opt.Optimize(e, checked) + if tc.error != "" { + if iss.Err() == nil { + t.Errorf("got nil, wanted error containing %q", tc.error) + } else if !strings.Contains(iss.Err().Error(), tc.error) { + t.Errorf("got %q, wanted error containing %q", iss.Err().Error(), tc.error) + } + return + } + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := AstToString(optimized) + if err != nil { + t.Fatalf("AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} + func TestConstantFoldingOptimizerMacroElimination(t *testing.T) { tests := []struct { expr string diff --git a/common/decls/decls.go b/common/decls/decls.go index a0fa6bcbd..8a43a7eef 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -281,6 +281,9 @@ func (f *FunctionDecl) 
AddOverload(overload *OverloadDecl) error { } return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID) } + if overload.HasLateBinding() != o.HasLateBinding() { + return fmt.Errorf("overload with late binding cannot be added to function %s: cannot mix late and non-late bindings", f.Name()) + } } f.overloadOrdinals = append(f.overloadOrdinals, overload.ID()) f.overloads[overload.ID()] = overload @@ -300,6 +303,19 @@ func (f *FunctionDecl) OverloadDecls() []*OverloadDecl { return overloads } +// Returns true if the function has late bindings. A function cannot mix late bindings with other bindings. +func (f *FunctionDecl) HasLateBinding() bool { + if f == nil { + return false + } + for _, oID := range f.overloadOrdinals { + if f.overloads[oID].HasLateBinding() { + return true + } + } + return false +} + // Bindings produces a set of function bindings, if any are defined. func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) { var emptySet []*functions.Overload @@ -308,8 +324,10 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) { } overloads := []*functions.Overload{} nonStrict := false + hasLateBinding := false for _, oID := range f.overloadOrdinals { o := f.overloads[oID] + hasLateBinding = hasLateBinding || o.HasLateBinding() if o.hasBinding() { overload := &functions.Overload{ Operator: o.ID(), @@ -327,6 +345,9 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) { if len(overloads) != 0 { return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name()) } + if hasLateBinding { + return nil, fmt.Errorf("singleton function incompatible with late bindings: %s", f.Name()) + } overloads = []*functions.Overload{ { Operator: f.Name(), @@ -576,6 +597,9 @@ type OverloadDecl struct { argTypes []*types.Type resultType *types.Type isMemberFunction bool + // hasLateBinding indicates that the function has a binding which is not known at compile time. 
+ // This is useful for functions which have side-effects or are not deterministically computable. + hasLateBinding bool // nonStrict indicates that the function will accept error and unknown arguments as inputs. nonStrict bool // operandTrait indicates whether the member argument should have a specific type-trait. @@ -640,6 +664,14 @@ func (o *OverloadDecl) IsNonStrict() bool { return o.nonStrict } +// HasLateBinding returns whether the overload has a binding which is not known at compile time. +func (o *OverloadDecl) HasLateBinding() bool { + if o == nil { + return false + } + return o.hasLateBinding +} + // OperandTrait returns the trait mask of the first operand to the overload call, e.g. // `traits.Indexer` func (o *OverloadDecl) OperandTrait() int { @@ -816,6 +848,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt { if len(o.ArgTypes()) != 1 { return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID()) } + if o.hasLateBinding { + return nil, fmt.Errorf("overload already has a late binding: %s", o.ID()) + } o.unaryOp = binding return o, nil } @@ -831,6 +866,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt { if len(o.ArgTypes()) != 2 { return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID()) } + if o.hasLateBinding { + return nil, fmt.Errorf("overload already has a late binding: %s", o.ID()) + } o.binaryOp = binding return o, nil } @@ -843,11 +881,26 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.ID()) } + if o.hasLateBinding { + return nil, fmt.Errorf("overload already has a late binding: %s", o.ID()) + } o.functionOp = binding return o, nil } } +// LateFunctionBinding indicates that the function has a binding which is not known at compile time. +// This is useful for functions which have side-effects or are not deterministically computable. 
+func LateFunctionBinding() OverloadOpt { + return func(o *OverloadDecl) (*OverloadDecl, error) { + if o.hasBinding() { + return nil, fmt.Errorf("overload already has a binding: %s", o.ID()) + } + o.hasLateBinding = true + return o, nil + } +} + // OverloadIsNonStrict enables the function to be called with error and unknown argument values. // // Note: do not use this option unless absoluately necessary as it should be an uncommon feature. diff --git a/common/decls/decls_test.go b/common/decls/decls_test.go index 262ef355d..782912bb2 100644 --- a/common/decls/decls_test.go +++ b/common/decls/decls_test.go @@ -613,6 +613,24 @@ func TestSingletonOverloadCollision(t *testing.T) { } } +func TestSingletonOverloadLateBindingCollision(t *testing.T) { + fn, err := NewFunction("id", + Overload("id_any", []*types.Type{types.AnyType}, types.AnyType, + LateFunctionBinding(), + ), + SingletonUnaryBinding(func(arg ref.Val) ref.Val { + return arg + }), + ) + if err != nil { + t.Fatalf("NewFunction() failed: %v", err) + } + _, err = fn.Bindings() + if err == nil || !strings.Contains(err.Error(), "incompatible with late bindings") { + t.Errorf("NewFunction() got %v, wanted incompatible with late bindings", err) + } +} + func TestSingletonUnaryBindingRedefinition(t *testing.T) { _, err := NewFunction("id", Overload("id_any", []*types.Type{types.AnyType}, types.AnyType), @@ -732,6 +750,74 @@ func TestOverloadFunctionBindingRedefinition(t *testing.T) { } } +func TestOverloadFunctionLateBinding(t *testing.T) { + function, err := NewFunction("id", + Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, LateFunctionBinding(), LateFunctionBinding()), + ) + if err != nil { + t.Fatalf("NewFunction() failed: %v", err) + } + if len(function.OverloadDecls()) != 1 { + t.Fatalf("NewFunction() got %v, wanted 1 overload", function.OverloadDecls()) + } + if !function.OverloadDecls()[0].HasLateBinding() { + t.Errorf("overload %v did not have a late binding", function.OverloadDecls()[0]) 
+ } +} + +func TestOverloadFunctionMixLateAndNonLateBinding(t *testing.T) { + _, err := NewFunction("id", + Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, LateFunctionBinding()), + Overload("id_int", []*types.Type{types.IntType}, types.AnyType), + ) + if err == nil || !strings.Contains(err.Error(), "cannot mix late and non-late bindings") { + t.Errorf("NewCustomEnv() got %v, wanted cannot mix late and non-late bindings", err) + } +} + +func TestOverloadFunctionBindingWithLateBinding(t *testing.T) { + _, err := NewFunction("id", + Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, FunctionBinding(func(args ...ref.Val) ref.Val { + return args[0] + }), LateFunctionBinding()), + ) + if err == nil || !strings.Contains(err.Error(), "already has a binding") { + t.Errorf("NewCustomEnv() got %v, wanted already has a binding", err) + } +} + +func TestOverloadFunctionLateBindingWithBinding(t *testing.T) { + _, err := NewFunction("id", + Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, LateFunctionBinding(), + FunctionBinding(func(args ...ref.Val) ref.Val { + return args[0] + })), + ) + if err == nil || !strings.Contains(err.Error(), "already has a late binding") { + t.Errorf("NewCustomEnv() got %v, wanted already has a late binding", err) + } + + _, err = NewFunction("id", + Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, LateFunctionBinding(), + UnaryBinding(func(arg ref.Val) ref.Val { + return arg + })), + ) + if err == nil || !strings.Contains(err.Error(), "already has a late binding") { + t.Errorf("NewCustomEnv() got %v, wanted already has a late binding", err) + } + + _, err = NewFunction("id", + Overload("id_bool", []*types.Type{types.BoolType, types.BoolType}, types.AnyType, LateFunctionBinding(), + BinaryBinding(func(arg1 ref.Val, arg2 ref.Val) ref.Val { + return arg1 + })), + ) + if err == nil || !strings.Contains(err.Error(), "already has a late binding") { + t.Errorf("NewCustomEnv() got %v, wanted 
already has a late binding", err) + } +} + func TestOverloadIsNonStrict(t *testing.T) { fn, err := NewFunction("getOrDefault", MemberOverload("get", From 55657d8c2be67b463fa9812029b0ebf8c862d40e Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 14 Apr 2025 11:01:31 -0700 Subject: [PATCH 33/46] Document optional library and increase docs coverage (#1162) * Document optional library and increase docs coverage * Fix doc merge on function decl merge --- cel/decls.go | 14 +-- cel/env_test.go | 7 +- cel/library.go | 91 ++++++++++++++++--- cel/macro_test.go | 90 ++++++++++++++++++ cel/prompt.go | 2 +- cel/prompt_test.go | 6 +- cel/testdata/BUILD.bazel | 2 +- .../{basic.prompt.md => basic.prompt.txt} | 2 +- .../{macros.prompt.md => macros.prompt.txt} | 2 +- ..._env.prompt.md => standard_env.prompt.txt} | 2 +- common/decls/decls.go | 2 +- common/env/env.go | 2 +- common/env/testdata/subset_env.yaml | 1 - parser/macro.go | 13 +-- parser/macro_test.go | 72 +++++++++++++++ 15 files changed, 258 insertions(+), 50 deletions(-) create mode 100644 cel/macro_test.go rename cel/testdata/{basic.prompt.md => basic.prompt.txt} (98%) rename cel/testdata/{macros.prompt.md => macros.prompt.txt} (99%) rename cel/testdata/{standard_env.prompt.md => standard_env.prompt.txt} (99%) create mode 100644 parser/macro_test.go diff --git a/cel/decls.go b/cel/decls.go index ab5a0b8c3..4d4873bd6 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -142,10 +142,7 @@ func Constant(name string, t *Type, v ref.Val) EnvOption { // Variable creates an instance of a variable declaration with a variable name and type. func Variable(name string, t *Type) EnvOption { - return func(e *Env) (*Env, error) { - e.variables = append(e.variables, decls.NewVariable(name, t)) - return e, nil - } + return VariableWithDoc(name, t, "") } // VariableWithDoc creates an instance of a variable declaration with a variable name, type, and doc string. 
@@ -201,14 +198,7 @@ func Function(name string, opts ...FunctionOpt) EnvOption { if err != nil { return nil, err } - if existing, found := e.functions[fn.Name()]; found { - fn, err = existing.Merge(fn) - if err != nil { - return nil, err - } - } - e.functions[fn.Name()] = fn - return e, nil + return FunctionDecls(fn)(e) } } diff --git a/cel/env_test.go b/cel/env_test.go index 38ae5e4cd..325fad3bf 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -378,11 +378,14 @@ func TestEnvToConfig(t *testing.T) { name: "optional lib - alt last()", opts: []EnvOption{ OptionalTypes(), - Function("last", MemberOverload("string_last", []*Type{StringType}, StringType)), + Function("last", + FunctionDocs(`return the last value in a list, or last character in a string`), + MemberOverload("string_last", []*Type{StringType}, StringType)), }, want: env.NewConfig("optional lib - alt last()"). AddExtensions(env.NewExtension("optional", math.MaxUint32)). - AddFunctions(env.NewFunction("last", + AddFunctions(env.NewFunctionWithDoc("last", + `return the last value in a list, or last character in a string`, env.NewMemberOverload("string_last", env.NewTypeDesc("string"), []*env.TypeDesc{}, env.NewTypeDesc("string")), )), }, diff --git a/cel/library.go b/cel/library.go index fde16a019..59a10e81d 100644 --- a/cel/library.go +++ b/cel/library.go @@ -18,6 +18,7 @@ import ( "fmt" "math" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/env" @@ -421,16 +422,29 @@ func (lib *optionalLib) CompileOptions() []EnvOption { Types(types.OptionalType), // Configure the optMap and optFlatMap macros. 
- Macros(ReceiverMacro(optMapMacro, 2, optMap)), + Macros(ReceiverMacro(optMapMacro, 2, optMap, + MacroDocs(`perform computation on the value if present and return the result as an optional`), + MacroExamples( + common.MultilineDescription( + `// sub with the prefix 'dev.cel' or optional.none()`, + `request.auth.tokens.?sub.optMap(id, 'dev.cel.' + id)`), + `optional.none().optMap(i, i * 2) // optional.none()`))), // Global and member functions for working with optional values. Function(optionalOfFunc, + FunctionDocs(`create a new optional_type(T) with a value where any value is considered valid`), Overload("optional_of", []*Type{paramTypeV}, optionalTypeV, + OverloadExamples(`optional.of(1) // optional(1)`), UnaryBinding(func(value ref.Val) ref.Val { return types.OptionalOf(value) }))), Function(optionalOfNonZeroValueFunc, + FunctionDocs(`create a new optional_type(T) with a value, if the value is not a zero or empty value`), Overload("optional_ofNonZeroValue", []*Type{paramTypeV}, optionalTypeV, + OverloadExamples( + `optional.ofNonZeroValue(null) // optional.none()`, + `optional.ofNonZeroValue("") // optional.none()`, + `optional.ofNonZeroValue("hello") // optional.of('hello')`), UnaryBinding(func(value ref.Val) ref.Val { v, isZeroer := value.(traits.Zeroer) if !isZeroer || !v.IsZeroValue() { @@ -439,18 +453,26 @@ func (lib *optionalLib) CompileOptions() []EnvOption { return types.OptionalNone }))), Function(optionalNoneFunc, + FunctionDocs(`singleton value representing an optional without a value`), Overload("optional_none", []*Type{}, optionalTypeV, + OverloadExamples(`optional.none()`), FunctionBinding(func(values ...ref.Val) ref.Val { return types.OptionalNone }))), Function(valueFunc, + FunctionDocs(`obtain the value contained by the optional, error if optional.none()`), MemberOverload("optional_value", []*Type{optionalTypeV}, paramTypeV, + OverloadExamples( + `optional.of(1).value() // 1`, + `optional.none().value() // error`), UnaryBinding(func(value 
ref.Val) ref.Val { opt := value.(*types.Optional) return opt.GetValue() }))), Function(hasValueFunc, + FunctionDocs(`determine whether the optional contains a value`), MemberOverload("optional_hasValue", []*Type{optionalTypeV}, BoolType, + OverloadExamples(`optional.of({1: 2}).hasValue() // true`), UnaryBinding(func(value ref.Val) ref.Val { opt := value.(*types.Optional) return types.Bool(opt.HasValue()) @@ -459,21 +481,43 @@ func (lib *optionalLib) CompileOptions() []EnvOption { // Implementation of 'or' and 'orValue' are special-cased to support short-circuiting in the // evaluation chain. Function("or", - MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV)), + FunctionDocs(`chain optional expressions together, picking the first valued optional expression`), + MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV, + OverloadExamples( + `optional.none().or(optional.of(1)) // optional.of(1)`, + common.MultilineDescription( + `// either a value from the first list, a value from the second, or optional.none()`, + `[1, 2, 3][?x].or([3, 4, 5][?y])`)))), Function("orValue", - MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV)), + FunctionDocs(`chain optional expressions together picking the first valued optional or the default value`), + MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV, + OverloadExamples( + common.MultilineDescription( + `// pick the value for the given key if the key exists, otherwise return 'you'`, + `{'hello': 'world', 'goodbye': 'cruel world'}[?greeting].orValue('you')`)))), // OptSelect is handled specially by the type-checker, so the receiver's field type is used to determine the // optput type. 
Function(operators.OptSelect, - Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV)), + FunctionDocs(`if the field is present create an optional of the field value, otherwise return optional.none()`), + Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV, + OverloadExamples( + `msg.?field // optional.of(field) if non-empty, otherwise optional.none()`, + `msg.?field.?nested_field // optional.of(nested_field) if both field and nested_field are non-empty.`))), // OptIndex is handled mostly like any other indexing operation on a list or map, so the type-checker can use // these signatures to determine type-agreement without any special handling. Function(operators.OptIndex, - Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV), + FunctionDocs(`if the index is present create an optional of the field value, otherwise return optional.none()`), + Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV, + OverloadExamples(`[1, 2, 3][?x] // element value if x is in the list size, else optional.none()`)), Overload("optional_list_optindex_optional_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV), - Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV), + Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV, + OverloadExamples( + `map_value[?key] // value at the key if present, else optional.none()`, + common.MultilineDescription( + `// map key-value if index is a valid map key, else optional.none()`, + `{0: 2, 2: 4, 6: 8}[?index]`))), Overload("optional_map_optindex_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)), // Index overloads to accommodate using an optional value as the operand. 
@@ -482,45 +526,62 @@ func (lib *optionalLib) CompileOptions() []EnvOption { Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)), } if lib.version >= 1 { - opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap))) + opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap, + MacroDocs(`perform computation on the value if present and produce an optional value within the computation`), + MacroExamples( + common.MultilineDescription( + `// m = {'key': {}}`, + `m.?key.optFlatMap(k, k.?subkey) // optional.none()`), + common.MultilineDescription( + `// m = {'key': {'subkey': 'value'}}`, + `m.?key.optFlatMap(k, k.?subkey) // optional.of('value')`), + )))) } if lib.version >= 2 { opts = append(opts, Function("last", + FunctionDocs(`return the last value in a list if present, otherwise optional.none()`), MemberOverload("list_last", []*Type{listTypeV}, optionalTypeV, + OverloadExamples( + `[].last() // optional.none()`, + `[1, 2, 3].last() ? optional.of(3)`), UnaryBinding(func(v ref.Val) ref.Val { list := v.(traits.Lister) - sz := list.Size().Value().(int64) - - if sz == 0 { + sz := list.Size().(types.Int) + if sz == types.IntZero { return types.OptionalNone } - return types.OptionalOf(list.Get(types.Int(sz - 1))) }), ), )) opts = append(opts, Function("first", + FunctionDocs(`return the first value in a list if present, otherwise optional.none()`), MemberOverload("list_first", []*Type{listTypeV}, optionalTypeV, + OverloadExamples( + `[].first() // optional.none()`, + `[1, 2, 3].first() ? 
optional.of(1)`), UnaryBinding(func(v ref.Val) ref.Val { list := v.(traits.Lister) - sz := list.Size().Value().(int64) - - if sz == 0 { + sz := list.Size().(types.Int) + if sz == types.IntZero { return types.OptionalNone } - return types.OptionalOf(list.Get(types.Int(0))) }), ), )) opts = append(opts, Function(optionalUnwrapFunc, + FunctionDocs(`convert a list of optional values to a list containing only value which are not optional.none()`), Overload("optional_unwrap", []*Type{listOptionalTypeV}, listTypeV, + OverloadExamples(`optional.unwrap([optional.of(1), optional.none()]) // [1]`), UnaryBinding(optUnwrap)))) opts = append(opts, Function(unwrapOptFunc, + FunctionDocs(`convert a list of optional values to a list containing only value which are not optional.none()`), MemberOverload("optional_unwrapOpt", []*Type{listOptionalTypeV}, listTypeV, + OverloadExamples(`[optional.of(1), optional.none()].unwrapOpt() // [1]`), UnaryBinding(optUnwrap)))) } diff --git a/cel/macro_test.go b/cel/macro_test.go new file mode 100644 index 000000000..f6400b020 --- /dev/null +++ b/cel/macro_test.go @@ -0,0 +1,90 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package cel + +import ( + "testing" + + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" +) + +func TestGlobalVarArgMacro(t *testing.T) { + noopExpander := func(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + return nil, nil + } + varArgMacro := GlobalVarArgMacro("varargs", noopExpander) + if varArgMacro.ArgCount() != 0 { + t.Errorf("ArgCount() got %d, wanted 0", varArgMacro.ArgCount()) + } + if varArgMacro.Function() != "varargs" { + t.Errorf("Function() got %q, wanted 'varargs'", varArgMacro.Function()) + } + if varArgMacro.MacroKey() != "varargs:*:false" { + t.Errorf("MacroKey() got %q, wanted 'varargs:*:false'", varArgMacro.MacroKey()) + } + if varArgMacro.IsReceiverStyle() { + t.Errorf("IsReceiverStyle() got %t, wanted false", varArgMacro.IsReceiverStyle()) + } +} + +func TestReceiverVarArgMacro(t *testing.T) { + noopExpander := func(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + return nil, nil + } + varArgMacro := ReceiverVarArgMacro("varargs", noopExpander) + if varArgMacro.ArgCount() != 0 { + t.Errorf("ArgCount() got %d, wanted 0", varArgMacro.ArgCount()) + } + if varArgMacro.Function() != "varargs" { + t.Errorf("Function() got %q, wanted 'varargs'", varArgMacro.Function()) + } + if varArgMacro.MacroKey() != "varargs:*:true" { + t.Errorf("MacroKey() got %q, wanted 'varargs:*:true'", varArgMacro.MacroKey()) + } + if !varArgMacro.IsReceiverStyle() { + t.Errorf("IsReceiverStyle() got %t, wanted true", varArgMacro.IsReceiverStyle()) + } +} + +func TestDocumentation(t *testing.T) { + noopExpander := func(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + return nil, nil + } + varArgMacro := ReceiverVarArgMacro("varargs", noopExpander, + MacroDocs(`convert variable argument lists to a list literal`), + MacroExamples(`fn.varargs(1,2,3) // fn([1, 2, 3])`)) + doc, ok := varArgMacro.(common.Documentor) + if !ok { + 
t.Fatal("macro does not implement Documenter interface") + } + d := doc.Documentation() + if d.Kind != common.DocMacro { + t.Errorf("Documentation() got kind %v, wanted DocMacro", d.Kind) + } + if d.Name != varArgMacro.Function() { + t.Errorf("Documentation() got name %q, wanted %q", d.Name, varArgMacro.Function()) + } + if d.Description != `convert variable argument lists to a list literal` { + t.Errorf("Documentation() got description %q, wanted %q", d.Description, `convert variable argument lists to a list literal`) + } + if len(d.Children) != 1 { + t.Fatalf("macro documentation children got: %d", len(d.Children)) + } + if d.Children[0].Description != `fn.varargs(1,2,3) // fn([1, 2, 3])` { + t.Errorf("macro documentation Children[0] got %s, wanted %s", d.Children[0].Description, + `fn.varargs(1,2,3) // fn([1, 2, 3])`) + } +} diff --git a/cel/prompt.go b/cel/prompt.go index 3f4ed1e1c..929a26f91 100644 --- a/cel/prompt.go +++ b/cel/prompt.go @@ -116,7 +116,7 @@ const ( defaultPersona = `You are a software engineer with expertise in networking and application security authoring boolean Common Expression Language (CEL) expressions to ensure firewall, networking, authentication, and data access is only permitted when all conditions -are satisified.` +are satisfied.` defaultFormatRules = `Output your response as a CEL expression. 
diff --git a/cel/prompt_test.go b/cel/prompt_test.go index 5598cd9b8..c4d3e6358 100644 --- a/cel/prompt_test.go +++ b/cel/prompt_test.go @@ -22,13 +22,13 @@ import ( "github.com/google/cel-go/test" ) -//go:embed testdata/basic.prompt.md +//go:embed testdata/basic.prompt.txt var wantBasicPrompt string -//go:embed testdata/macros.prompt.md +//go:embed testdata/macros.prompt.txt var wantMacrosPrompt string -//go:embed testdata/standard_env.prompt.md +//go:embed testdata/standard_env.prompt.txt var wantStandardEnvPrompt string func TestPromptTemplate(t *testing.T) { diff --git a/cel/testdata/BUILD.bazel b/cel/testdata/BUILD.bazel index e643fbcae..96ca73c8c 100644 --- a/cel/testdata/BUILD.bazel +++ b/cel/testdata/BUILD.bazel @@ -19,5 +19,5 @@ genrule( filegroup( name = "prompts", - srcs = glob(["*.prompt.md"]), + srcs = glob(["*.prompt.txt"]), ) \ No newline at end of file diff --git a/cel/testdata/basic.prompt.md b/cel/testdata/basic.prompt.txt similarity index 98% rename from cel/testdata/basic.prompt.md rename to cel/testdata/basic.prompt.txt index 5f09cef8c..f396d2093 100644 --- a/cel/testdata/basic.prompt.md +++ b/cel/testdata/basic.prompt.txt @@ -1,7 +1,7 @@ You are a software engineer with expertise in networking and application security authoring boolean Common Expression Language (CEL) expressions to ensure firewall, networking, authentication, and data access is only permitted when all conditions -are satisified. +are satisfied. Output your response as a CEL expression. 
diff --git a/cel/testdata/macros.prompt.md b/cel/testdata/macros.prompt.txt similarity index 99% rename from cel/testdata/macros.prompt.md rename to cel/testdata/macros.prompt.txt index 9e6f0934a..91793d4c2 100644 --- a/cel/testdata/macros.prompt.md +++ b/cel/testdata/macros.prompt.txt @@ -1,7 +1,7 @@ You are a software engineer with expertise in networking and application security authoring boolean Common Expression Language (CEL) expressions to ensure firewall, networking, authentication, and data access is only permitted when all conditions -are satisified. +are satisfied. Output your response as a CEL expression. diff --git a/cel/testdata/standard_env.prompt.md b/cel/testdata/standard_env.prompt.txt similarity index 99% rename from cel/testdata/standard_env.prompt.md rename to cel/testdata/standard_env.prompt.txt index b7a3010b0..18f0fafa4 100644 --- a/cel/testdata/standard_env.prompt.md +++ b/cel/testdata/standard_env.prompt.txt @@ -1,7 +1,7 @@ You are a software engineer with expertise in networking and application security authoring boolean Common Expression Language (CEL) expressions to ensure firewall, networking, authentication, and data access is only permitted when all conditions -are satisified. +are satisfied. Output your response as a CEL expression. 
diff --git a/common/decls/decls.go b/common/decls/decls.go index 8a43a7eef..759b1d16b 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -165,7 +165,7 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) { } // Allow for non-empty overrides of documentation if len(other.doc) != 0 && f.doc != other.doc { - f.doc = other.doc + merged.doc = other.doc } // baseline copy of the overloads and their ordinals copy(merged.overloadOrdinals, f.overloadOrdinals) diff --git a/common/env/env.go b/common/env/env.go index 4f2bebade..d848860c2 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -233,7 +233,7 @@ func NewVariable(name string, t *TypeDesc) *Variable { return NewVariableWithDoc(name, t, "") } -// NewVariable returns a serializable variable from a name, type definition, and doc string. +// NewVariableWithDoc returns a serializable variable from a name, type definition, and doc string. func NewVariableWithDoc(name string, t *TypeDesc, doc string) *Variable { return &Variable{Name: name, TypeDesc: t, Description: doc} } diff --git a/common/env/testdata/subset_env.yaml b/common/env/testdata/subset_env.yaml index 53adef486..44437e718 100644 --- a/common/env/testdata/subset_env.yaml +++ b/common/env/testdata/subset_env.yaml @@ -33,7 +33,6 @@ stdlib: variables: - name: "x" type_name: "int" - description: - name: "y" type_name: "double" - name: "z" diff --git a/parser/macro.go b/parser/macro.go index 7661755e2..1ef43c4b5 100644 --- a/parser/macro.go +++ b/parser/macro.go @@ -163,20 +163,13 @@ func (m *macro) MacroKey() string { return makeMacroKey(m.function, m.argCount, m.receiverStyle) } +// Documentation generates documentation and examples for the macro. 
func (m *macro) Documentation() *common.Doc { examples := make([]*common.Doc, len(m.examples)) for i, ex := range m.examples { - examples[i] = &common.Doc{ - Kind: common.DocExample, - Description: ex, - } - } - return &common.Doc{ - Kind: common.DocMacro, - Name: m.Function(), - Description: common.ParseDescription(m.doc), - Children: examples, + examples[i] = common.NewExampleDoc(ex) } + return common.NewMacroDoc(m.Function(), m.doc, examples...) } func makeMacroKey(name string, args int, receiverStyle bool) string { diff --git a/parser/macro_test.go b/parser/macro_test.go new file mode 100644 index 000000000..1d056f644 --- /dev/null +++ b/parser/macro_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package parser + +import ( + "testing" + + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" +) + +func TestReceiverVarArgMacro(t *testing.T) { + noopExpander := func(meh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + return nil, nil + } + varArgMacro := NewReceiverVarArgMacro("varargs", noopExpander, + MacroDocs(`convert variable argument lists to a list literal`), + MacroExamples(`varargs(1,2,3) // [1, 2, 3]`)) + if varArgMacro.ArgCount() != 0 { + t.Errorf("ArgCount() got %d, wanted 0", varArgMacro.ArgCount()) + } + if varArgMacro.Function() != "varargs" { + t.Errorf("Function() got %q, wanted 'varargs'", varArgMacro.Function()) + } + if varArgMacro.MacroKey() != "varargs:*:true" { + t.Errorf("MacroKey() got %q, wanted 'varargs:*:true'", varArgMacro.MacroKey()) + } + if !varArgMacro.IsReceiverStyle() { + t.Errorf("IsReceiverStyle() got %t, wanted true", varArgMacro.IsReceiverStyle()) + } +} + +func TestDocumentation(t *testing.T) { + noopExpander := func(meh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + return nil, nil + } + varArgMacro := NewReceiverVarArgMacro("varargs", noopExpander, + MacroDocs(`convert variable argument lists to a list literal`), + MacroExamples(`varargs(1,2,3) // [1, 2, 3]`)) + doc, ok := varArgMacro.(common.Documentor) + if !ok { + t.Fatal("macro does not implement Documenter interface") + } + d := doc.Documentation() + if d.Kind != common.DocMacro { + t.Errorf("Documentation() got kind %v, wanted DocMacro", d.Kind) + } + if d.Name != varArgMacro.Function() { + t.Errorf("Documentation() got name %q, wanted %q", d.Name, varArgMacro.Function()) + } + if d.Description != `convert variable argument lists to a list literal` { + t.Errorf("Documentation() got description %q, wanted %q", d.Description, `convert variable argument lists to a list literal`) + } + if len(d.Children) != 1 { + t.Fatalf("macro documentation children got: %d", len(d.Children)) + } + if 
d.Children[0].Description != `varargs(1,2,3) // [1, 2, 3]` { + t.Errorf("macro documentation Children[0] got %s, wanted %s", d.Children[0].Description, `varargs(1,2,3) // [1, 2, 3]`) + } +} From 3da6139fe4e19569f15101d2bc71ec2894a0b9b5 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 14 Apr 2025 11:01:44 -0700 Subject: [PATCH 34/46] Initial stateful observers prior to evaluation (#1163) --- cel/cel_test.go | 90 ++++++++++++++++++++++++++++++++++++ interpreter/interpretable.go | 3 ++ 2 files changed, 93 insertions(+) diff --git a/cel/cel_test.go b/cel/cel_test.go index 97e4f9f0f..cbe89c3e7 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -17,6 +17,7 @@ package cel import ( "bytes" "context" + "errors" "fmt" "os" "reflect" @@ -1766,6 +1767,95 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { } } +func TestCostLimit(t *testing.T) { + cases := []struct { + name string + expr string + decls []EnvOption + costLimit uint64 + in any + err error + }{ + { + name: "greater", + expr: `val1 > val2`, + decls: []EnvOption{ + Variable("val1", IntType), + Variable("val2", IntType), + }, + in: map[string]any{"val1": 1, "val2": 2}, + costLimit: 10, + }, + { + name: "greater - error", + expr: `val1 > val2`, + decls: []EnvOption{ + Variable("val1", IntType), + Variable("val2", IntType), + }, + in: map[string]any{"val1": 1, "val2": 2}, + costLimit: 0, + err: errors.New("actual cost limit exceeded"), + }, + } + + for _, tst := range cases { + tc := tst + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + envOpts := []EnvOption{ + CostEstimatorOptions( + checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear), + ), + } + envOpts = append(envOpts, tc.decls...) + env := testEnv(t, envOpts...) 
+ ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err()) + } + est, err := env.EstimateCost(ast, testCostEstimator{hints: map[string]uint64{}}) + if err != nil { + t.Fatalf("Env.EstimateCost(ast *Ast, estimator checker.CostEstimator) failed to estimate cost: %s\n", err) + } + + checkedAst, iss := env.Check(ast) + if iss.Err() != nil { + t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err()) + } + // Evaluate expression. + program, err := env.Program(checkedAst, + CostTracking(testRuntimeCostEstimator{}), + CostTrackerOptions( + interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear), + ), + CostLimit(tc.costLimit), + ) + if err != nil { + t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err) + } + _, details, err := program.Eval(tc.in) + if err != nil && tc.err == nil { + t.Fatalf(`Program.Eval(vars any) failed to evaluate expression: %v`, err) + } + actualCost := details.ActualCost() + if actualCost == nil { + t.Errorf(`EvalDetails.ActualCost() got nil for "%s" cost, wanted %d`, tc.expr, actualCost) + } + if err == nil { + if est.Min > *actualCost || est.Max < *actualCost { + t.Errorf("EvalDetails.ActualCost() failed to return a runtime cost %d is the range of estimate cost [%d, %d]", *actualCost, + est.Min, est.Max) + } + } else { + if !strings.Contains(err.Error(), tc.err.Error()) { + t.Fatalf("program.Eval() got error %v, wanted error containing %v", err, tc.err) + } + } + }) + } +} + func TestPartialVars(t *testing.T) { env := testEnv(t, Variable("x", StringType), diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 04bbf3ffe..96b5a8ffc 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -138,6 +138,9 @@ func (oi *ObservableInterpretable) ObserveEval(vars Activation, observer func(an if err != nil { return types.WrapErr(err) } + // Provide an initial reference to 
the state to ensure state is available + // even in cases of interrupting errors generated during evaluation. + observer(obs.GetState(vars)) } result := oi.Interpretable.Eval(vars) // Get the state which needs to be reported back as having been observed. From 13e52967af07b62c67c877f4f4ea35bcbb1caad7 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 14 Apr 2025 15:36:03 -0700 Subject: [PATCH 35/46] Unparse Expr values to strings (#1164) * Unparse Expr values to strings * ExprToString test case --- cel/io.go | 8 +++++++- cel/io_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/cel/io.go b/cel/io.go index a327c9672..7b1a4bed2 100644 --- a/cel/io.go +++ b/cel/io.go @@ -99,7 +99,13 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) { // Note, the conversion may not be an exact replica of the original expression, but will produce // a string that is semantically equivalent and whose textual representation is stable. func AstToString(a *Ast) (string, error) { - return parser.Unparse(a.NativeRep().Expr(), a.NativeRep().SourceInfo()) + return ExprToString(a.NativeRep().Expr(), a.NativeRep().SourceInfo()) +} + +// ExprToString converts an AST Expr node back to a string using macro call tracking metadata from +// source info if any macros are encountered within the expression. +func ExprToString(e ast.Expr, info *ast.SourceInfo) (string, error) { + return parser.Unparse(e, info) } // RefValueToValue converts between ref.Val and google.api.expr.v1alpha1.Value. 
diff --git a/cel/io_test.go b/cel/io_test.go index 7bc34ee94..d2b864495 100644 --- a/cel/io_test.go +++ b/cel/io_test.go @@ -23,6 +23,8 @@ import ( "google.golang.org/protobuf/proto" "github.com/google/cel-go/checker/decls" + celast "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" proto3pb "github.com/google/cel-go/test/proto3pb" @@ -145,6 +147,52 @@ func TestAstToString(t *testing.T) { } } +func TestExprToString(t *testing.T) { + stdEnv, err := NewEnv(EnableMacroCallTracking()) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + in := "[a, b].filter(i, (i > 0) ? (-i + 4) : i)" + ast, iss := stdEnv.Parse(in) + if iss.Err() != nil { + t.Fatalf("stdEnv.Parse(%q) failed: %v", in, iss.Err()) + } + expr, err := ExprToString(ast.NativeRep().Expr(), ast.NativeRep().SourceInfo()) + if err != nil { + t.Fatalf("ExprToString(ast) failed: %v", err) + } + if expr != in { + t.Errorf("got %v, wanted %v", expr, in) + } + + // Test sub-expression unparsing. + navExpr := celast.NavigateAST(ast.NativeRep()) + condExpr := celast.MatchDescendants(navExpr, celast.FunctionMatcher(operators.Conditional))[0] + want := `(i > 0) ? (-i + 4) : i` + expr, err = ExprToString(condExpr, ast.NativeRep().SourceInfo()) + if err != nil { + t.Fatalf("ExprToString(ast) failed: %v", err) + } + if expr != want { + t.Errorf("got %v, wanted %v", expr, want) + } + + // Also passes with a nil source info, but only because the sub-expr doesn't contain macro calls. + expr, err = ExprToString(condExpr, nil) + if err != nil { + t.Fatalf("ExprToString(ast) failed: %v", err) + } + if expr != want { + t.Errorf("got %v, wanted %v", expr, want) + } + + // Fails due to missing macro information. 
+ _, err = ExprToString(ast.NativeRep().Expr(), nil) + if err == nil { + t.Error("ExprToString() succeeded, wanted error") + } +} + func TestAstToStringNil(t *testing.T) { expr, err := AstToString(nil) if err == nil || !strings.Contains(err.Error(), "unsupported expr") { From 2aa9572df845e4d1e0d9e6071fd5ddf927513c11 Mon Sep 17 00:00:00 2001 From: "zev.ac" Date: Wed, 16 Apr 2025 00:45:10 +0530 Subject: [PATCH 36/46] Add test runner library (#1149) Test runner library for compiling and executing CEL unit tests with Bazel. --- WORKSPACE | 6 +- conformance/go.mod | 2 +- conformance/go.sum | 4 +- go.mod | 2 +- go.sum | 4 +- policy/BUILD.bazel | 7 +- policy/compiler_test.go | 69 +- policy/conformance.go | 2 + policy/go.mod | 2 +- policy/go.sum | 4 +- policy/helper_test.go | 5 +- policy/testdata/context_pb/config.yaml | 2 - policy/testdata/context_pb/tests.yaml | 21 +- policy/testdata/k8s/config.yaml | 4 - policy/testdata/k8s/tests.yaml | 3 +- policy/testdata/limits/tests.yaml | 12 +- policy/testdata/nested_rule/tests.yaml | 9 +- policy/testdata/nested_rule2/tests.yaml | 12 +- policy/testdata/nested_rule3/tests.yaml | 12 +- policy/testdata/nested_rule4/tests.yaml | 6 +- policy/testdata/nested_rule5/tests.yaml | 12 +- policy/testdata/nested_rule6/tests.yaml | 3 +- policy/testdata/nested_rule7/tests.yaml | 12 +- policy/testdata/pb/tests.yaml | 7 +- policy/testdata/required_labels/config.yaml | 1 - policy/testdata/required_labels/tests.yaml | 15 +- .../restricted_destinations/base_config.yaml | 44 +- .../restricted_destinations/tests.yaml | 12 +- policy/testdata/unnest/tests.yaml | 18 +- test/BUILD.bazel | 3 + test/suite.go | 60 ++ tools/celtest/BUILD.bazel | 77 ++ tools/celtest/test_runner.go | 707 ++++++++++++++++++ tools/celtest/test_runner_test.go | 196 +++++ .../testdata/config.yaml} | 10 +- .../testdata/custom_policy.celpolicy | 1 + .../celtest/testdata/custom_policy_tests.yaml | 42 ++ tools/celtest/testdata/raw_expr.cel | 1 + tools/celtest/testdata/raw_expr_tests.yaml 
| 34 + tools/compiler/BUILD.bazel | 5 +- tools/compiler/compiler.go | 76 +- tools/compiler/compiler_test.go | 79 +- tools/compiler/testdata/config.yaml | 2 - tools/go.mod | 3 +- tools/go.sum | 4 +- vendor/cel.dev/expr/MODULE.bazel | 10 +- vendor/modules.txt | 4 +- 47 files changed, 1369 insertions(+), 257 deletions(-) create mode 100644 test/suite.go create mode 100644 tools/celtest/BUILD.bazel create mode 100644 tools/celtest/test_runner.go create mode 100644 tools/celtest/test_runner_test.go rename tools/{compiler/testdata/custom_policy_config.yaml => celtest/testdata/config.yaml} (83%) rename tools/{compiler => celtest}/testdata/custom_policy.celpolicy (97%) create mode 100644 tools/celtest/testdata/custom_policy_tests.yaml create mode 100644 tools/celtest/testdata/raw_expr.cel create mode 100644 tools/celtest/testdata/raw_expr_tests.yaml diff --git a/WORKSPACE b/WORKSPACE index f566d7d09..956f16a4d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -101,8 +101,8 @@ go_repository( go_repository( name = "dev_cel_expr", importpath = "cel.dev/expr", - sum = "h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI=", - version = "v0.22.1", + sum = "h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg=", + version = "v0.23.1", ) # local_repository( @@ -153,7 +153,7 @@ go_repository( # of the above repositories but at different versions, so ours must come first. 
go_rules_dependencies() -go_register_toolchains(version = "1.21.1") +go_register_toolchains(version = "1.22.0") gazelle_dependencies() diff --git a/conformance/go.mod b/conformance/go.mod index 115630be9..92d2ca476 100644 --- a/conformance/go.mod +++ b/conformance/go.mod @@ -3,7 +3,7 @@ module github.com/google/cel-go/conformance go 1.22.0 require ( - cel.dev/expr v0.22.1 + cel.dev/expr v0.23.1 github.com/bazelbuild/rules_go v0.49.0 github.com/google/cel-go v0.21.0 github.com/google/go-cmp v0.6.0 diff --git a/conformance/go.sum b/conformance/go.sum index 9544b17a9..95ac73a52 100644 --- a/conformance/go.sum +++ b/conformance/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= -cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= +cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/bazelbuild/rules_go v0.49.0 h1:5vCbuvy8Q11g41lseGJDc5vxhDjJtfxr6nM/IC4VmqM= diff --git a/go.mod b/go.mod index 9f089f4fd..914c1ec28 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.22.0 toolchain go1.23.0 require ( - cel.dev/expr v0.22.1 + cel.dev/expr v0.23.1 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/stoewer/go-strcase v1.2.0 google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 diff --git a/go.sum b/go.sum index 062b316c3..23fe2170a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= -cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= +cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.0 
h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index f058f1ab9..15facc55d 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -63,6 +63,7 @@ go_test( embed = [":go_default_library"], deps = [ "//cel:go_default_library", + "//test:go_default_library", "//common/types:go_default_library", "//interpreter:go_default_library", "//common/types/ref:go_default_library", @@ -72,6 +73,6 @@ go_test( ) filegroup( - name = "k8s_policy_testdata", - srcs = glob(["testdata/k8s/*"]), -) \ No newline at end of file + name = "testdata", + srcs = glob(["testdata/**"]), +) diff --git a/policy/compiler_test.go b/policy/compiler_test.go index 545b01f8b..4bd1bfcce 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -243,24 +243,25 @@ func (r *runner) run(t *testing.T) { input := map[string]any{} var err error var activation interpreter.Activation - for k, v := range tc.Input { - if v.Expr != "" { - input[k] = r.eval(t, v.Expr) - continue + if tc.InputContext != nil && tc.InputContext.ContextExpr != "" { + ctxExpr := tc.InputContext.ContextExpr + ctx, err := r.eval(t, ctxExpr).ConvertToNative( + reflect.TypeOf(((*proto.Message)(nil))).Elem()) + if err != nil { + t.Fatalf("context variable is not a valid proto: %v", err) } - if v.ContextExpr != "" { - ctx, err := r.eval(t, v.ContextExpr).ConvertToNative( - reflect.TypeOf(((*proto.Message)(nil))).Elem()) - if err != nil { - t.Fatalf("context variable is not a valid proto: %v", err) - } - activation, err = cel.ContextProtoVars(ctx.(proto.Message)) - if err != nil { - t.Fatalf("cel.ContextProtoVars() failed: %v", err) + activation, err = cel.ContextProtoVars(ctx.(proto.Message)) + if err != nil { + t.Fatalf("cel.ContextProtoVars() failed: %v", err) + } + } else if len(tc.Input) != 0 { + for k, 
v := range tc.Input { + if v.Expr != "" { + input[k] = r.eval(t, v.Expr) + continue } - break + input[k] = v.Value } - input[k] = v.Value } if activation == nil { activation, err = interpreter.NewActivation(input) @@ -272,7 +273,12 @@ func (r *runner) run(t *testing.T) { if err != nil { t.Fatalf("prg.Eval(input) failed: %v", err) } - testOut := r.eval(t, tc.Output) + var testOut ref.Val + if tc.Output.Expr != "" { + testOut = r.eval(t, tc.Output.Expr) + } else if tc.Output.Value != nil { + testOut = r.env.CELTypeAdapter().NativeToValue(tc.Output.Value) + } if optOut, ok := out.(*types.Optional); ok { if optOut.Equal(types.OptionalNone) == types.True { if testOut.Equal(types.OptionalNone) != types.True { @@ -299,24 +305,25 @@ func (r *runner) bench(b *testing.B) { input := map[string]any{} var err error var activation interpreter.Activation - for k, v := range tc.Input { - if v.Expr != "" { - input[k] = r.eval(b, v.Expr) - continue + if tc.InputContext != nil && tc.InputContext.ContextExpr != "" { + ctxExpr := tc.InputContext.ContextExpr + ctx, err := r.eval(b, ctxExpr).ConvertToNative( + reflect.TypeOf(((*proto.Message)(nil))).Elem()) + if err != nil { + b.Fatalf("context variable is not a valid proto: %v", err) } - if v.ContextExpr != "" { - ctx, err := r.eval(b, v.ContextExpr).ConvertToNative( - reflect.TypeOf(((*proto.Message)(nil))).Elem()) - if err != nil { - b.Fatalf("context variable is not a valid proto: %v", err) - } - activation, err = cel.ContextProtoVars(ctx.(proto.Message)) - if err != nil { - b.Fatalf("cel.ContextProtoVars() failed: %v", err) + activation, err = cel.ContextProtoVars(ctx.(proto.Message)) + if err != nil { + b.Fatalf("cel.ContextProtoVars() failed: %v", err) + } + } else if tc.Input != nil { + for k, v := range tc.Input { + if v.Expr != "" { + input[k] = r.eval(b, v.Expr) + continue } - break + input[k] = v.Value } - input[k] = v.Value } if activation == nil { activation, err = interpreter.NewActivation(input) diff --git 
a/policy/conformance.go b/policy/conformance.go index 3d05f411c..160c5c87e 100644 --- a/policy/conformance.go +++ b/policy/conformance.go @@ -15,6 +15,8 @@ package policy // TestSuite describes a set of tests divided by section. +// +// Deprecated: Use google3/third_party/cel/go/test/suite.go instead. type TestSuite struct { Description string `yaml:"description"` Sections []*TestSection `yaml:"section"` diff --git a/policy/go.mod b/policy/go.mod index 410781f5d..400a3cac5 100644 --- a/policy/go.mod +++ b/policy/go.mod @@ -9,7 +9,7 @@ require ( ) require ( - cel.dev/expr v0.22.1 // indirect + cel.dev/expr v0.23.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect diff --git a/policy/go.sum b/policy/go.sum index 8b4ac4221..35ef4fed1 100644 --- a/policy/go.sum +++ b/policy/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= -cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= +cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/policy/helper_test.go b/policy/helper_test.go index 8e117331c..98ca4f7ab 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/test" "gopkg.in/yaml.v3" @@ -447,13 +448,13 @@ func readPolicyConfig(t testing.TB, fileName string) *env.Config { return config } -func readTestSuite(t testing.TB, fileName string) 
*TestSuite { +func readTestSuite(t testing.TB, fileName string) *test.Suite { t.Helper() testCaseBytes, err := os.ReadFile(fileName) if err != nil { t.Fatalf("os.ReadFile(%s) failed: %v", fileName, err) } - suite := &TestSuite{} + suite := &test.Suite{} err = yaml.Unmarshal(testCaseBytes, suite) if err != nil { log.Fatalf("yaml.Unmarshal(%s) error: %v", fileName, err) diff --git a/policy/testdata/context_pb/config.yaml b/policy/testdata/context_pb/config.yaml index e7804575c..53ea95425 100644 --- a/policy/testdata/context_pb/config.yaml +++ b/policy/testdata/context_pb/config.yaml @@ -15,8 +15,6 @@ name: "context_pb" container: "google.expr.proto3" extensions: - - name: "optional" - version: "latest" - name: "strings" version: 2 context_variable: diff --git a/policy/testdata/context_pb/tests.yaml b/policy/testdata/context_pb/tests.yaml index 11e377e53..80849783b 100644 --- a/policy/testdata/context_pb/tests.yaml +++ b/policy/testdata/context_pb/tests.yaml @@ -16,18 +16,13 @@ description: "Protobuf input tests" section: - name: "valid" tests: - - name: "good spec" - input: - spec: - context_expr: > - test.TestAllTypes{single_int32: 10} - output: "optional.none()" + - name: "good spec" + context_expr: "test.TestAllTypes{single_int32: 10}" + output: + expr: "optional.none()" - name: "invalid" tests: - - name: "bad spec" - input: - spec: - context_expr: > - test.TestAllTypes{single_int32: 11} - output: > - "invalid spec, got single_int32=11, wanted <= 10" + - name: "bad spec" + context_expr: "test.TestAllTypes{single_int32: 11}" + output: + value: "invalid spec, got single_int32=11, wanted <= 10" diff --git a/policy/testdata/k8s/config.yaml b/policy/testdata/k8s/config.yaml index 5a2cb3290..15a32b535 100644 --- a/policy/testdata/k8s/config.yaml +++ b/policy/testdata/k8s/config.yaml @@ -14,10 +14,6 @@ name: k8s extensions: - - name: "optional" - version: "latest" - - name: "bindings" - version: "latest" - name: "strings" version: 2 variables: diff --git 
a/policy/testdata/k8s/tests.yaml b/policy/testdata/k8s/tests.yaml index 3965ea0f9..f3e7de790 100644 --- a/policy/testdata/k8s/tests.yaml +++ b/policy/testdata/k8s/tests.yaml @@ -28,4 +28,5 @@ section: - staging.dev.cel.container1 - staging.dev.cel.container2 - preprod.dev.cel.container3 - output: "'only staging containers are allowed in namespace dev.cel'" + output: + value: "only staging containers are allowed in namespace dev.cel" diff --git a/policy/testdata/limits/tests.yaml b/policy/testdata/limits/tests.yaml index 8f50a519d..88772e075 100644 --- a/policy/testdata/limits/tests.yaml +++ b/policy/testdata/limits/tests.yaml @@ -20,19 +20,23 @@ section: input: now: expr: "timestamp('2024-07-30T00:30:00Z')" - output: "'hello, me'" + output: + value: "hello, me" - name: "8pm" input: now: expr: "timestamp('2024-07-30T20:30:00Z')" - output: "'goodbye, me!'" + output: + value: "goodbye, me!" - name: "9pm" input: now: expr: "timestamp('2024-07-30T21:30:00Z')" - output: "'goodbye, me!!'" + output: + value: "goodbye, me!!" - name: "11pm" input: now: expr: "timestamp('2024-07-30T23:30:00Z')" - output: "'goodbye, me!!!'" + output: + value: "goodbye, me!!!" 
diff --git a/policy/testdata/nested_rule/tests.yaml b/policy/testdata/nested_rule/tests.yaml index 48101c89a..3f9f63437 100644 --- a/policy/testdata/nested_rule/tests.yaml +++ b/policy/testdata/nested_rule/tests.yaml @@ -21,13 +21,15 @@ section: resource: value: origin: "ir" - output: "{'banned': true}" + output: + expr: "{'banned': true}" - name: "by_default" input: resource: value: origin: "de" - output: "{'banned': true}" + output: + expr: "{'banned': true}" - name: "permitted" tests: - name: "valid_origin" @@ -35,4 +37,5 @@ section: resource: value: origin: "uk" - output: "{'banned': false}" + output: + expr: "{'banned': false}" diff --git a/policy/testdata/nested_rule2/tests.yaml b/policy/testdata/nested_rule2/tests.yaml index ac725956c..0e1a9ca69 100644 --- a/policy/testdata/nested_rule2/tests.yaml +++ b/policy/testdata/nested_rule2/tests.yaml @@ -22,21 +22,24 @@ section: value: user: "bad-user" origin: "ir" - output: "{'banned': 'restricted_region'}" + output: + expr: "{'banned': 'restricted_region'}" - name: "by_default" input: resource: value: user: "bad-user" origin: "de" - output: "{'banned': 'bad_actor'}" + output: + expr: "{'banned': 'bad_actor'}" - name: "unconfigured_region" input: resource: value: user: "good-user" origin: "de" - output: "{'banned': 'unconfigured_region'}" + output: + expr: "{'banned': 'unconfigured_region'}" - name: "permitted" tests: - name: "valid_origin" @@ -45,4 +48,5 @@ section: value: user: "good-user" origin: "uk" - output: "{}" + output: + expr: "{}" diff --git a/policy/testdata/nested_rule3/tests.yaml b/policy/testdata/nested_rule3/tests.yaml index ece86eba0..9d993c65f 100644 --- a/policy/testdata/nested_rule3/tests.yaml +++ b/policy/testdata/nested_rule3/tests.yaml @@ -22,21 +22,24 @@ section: value: user: "bad-user" origin: "ir" - output: "{'banned': 'restricted_region'}" + output: + expr: "{'banned': 'restricted_region'}" - name: "by_default" input: resource: value: user: "bad-user" origin: "de" - output: "{'banned': 
'bad_actor'}" + output: + expr: "{'banned': 'bad_actor'}" - name: "unconfigured_region" input: resource: value: user: "good-user" origin: "de" - output: "{'banned': 'unconfigured_region'}" + output: + expr: "{'banned': 'unconfigured_region'}" - name: "permitted" tests: - name: "valid_origin" @@ -45,4 +48,5 @@ section: value: user: "good-user" origin: "uk" - output: "optional.none()" + output: + expr: "optional.none()" diff --git a/policy/testdata/nested_rule4/tests.yaml b/policy/testdata/nested_rule4/tests.yaml index a5af137f3..006eddb88 100644 --- a/policy/testdata/nested_rule4/tests.yaml +++ b/policy/testdata/nested_rule4/tests.yaml @@ -20,9 +20,11 @@ section: input: x: value: 0 - output: "false" + output: + value: false - name: "x=2" input: x: value: 2 - output: "true" + output: + value: true diff --git a/policy/testdata/nested_rule5/tests.yaml b/policy/testdata/nested_rule5/tests.yaml index 66cc44507..8cd794051 100644 --- a/policy/testdata/nested_rule5/tests.yaml +++ b/policy/testdata/nested_rule5/tests.yaml @@ -20,19 +20,23 @@ section: input: x: value: 0 - output: "false" + output: + value: false - name: "x=1" input: x: value: 1 - output: "optional.none()" + output: + expr: "optional.none()" - name: "x=2" input: x: value: 2 - output: "optional.none()" + output: + expr: "optional.none()" - name: "x=3" input: x: value: 3 - output: "true" + output: + value: true diff --git a/policy/testdata/nested_rule6/tests.yaml b/policy/testdata/nested_rule6/tests.yaml index dabce623c..fef586df0 100644 --- a/policy/testdata/nested_rule6/tests.yaml +++ b/policy/testdata/nested_rule6/tests.yaml @@ -20,4 +20,5 @@ section: input: x: value: 0 - output: "false" + output: + value: false diff --git a/policy/testdata/nested_rule7/tests.yaml b/policy/testdata/nested_rule7/tests.yaml index 7844e18f6..f740c7639 100644 --- a/policy/testdata/nested_rule7/tests.yaml +++ b/policy/testdata/nested_rule7/tests.yaml @@ -20,19 +20,23 @@ section: input: x: value: 1 - output: "optional.none()" + 
output: + expr: "optional.none()" - name: "x=2" input: x: value: 2 - output: "false" + output: + value: false - name: "x=3" input: x: value: 3 - output: "true" + output: + value: true - name: "x=4" input: x: value: 4 - output: "true" + output: + value: true diff --git a/policy/testdata/pb/tests.yaml b/policy/testdata/pb/tests.yaml index 770bcad09..a39f7b73f 100644 --- a/policy/testdata/pb/tests.yaml +++ b/policy/testdata/pb/tests.yaml @@ -21,7 +21,8 @@ section: spec: expr: > test.TestAllTypes{single_int32: 10} - output: "optional.none()" + output: + expr: "optional.none()" - name: "invalid" tests: - name: "bad spec" @@ -29,5 +30,5 @@ section: spec: expr: > test.TestAllTypes{single_int32: 11} - output: > - "invalid spec, got single_int32=11, wanted <= 10" + output: + value: "invalid spec, got single_int32=11, wanted <= 10" diff --git a/policy/testdata/required_labels/config.yaml b/policy/testdata/required_labels/config.yaml index f9081478a..c5c612e20 100644 --- a/policy/testdata/required_labels/config.yaml +++ b/policy/testdata/required_labels/config.yaml @@ -14,7 +14,6 @@ name: "labels" extensions: - - name: "bindings" - name: "strings" version: 2 - name: "two-var-comprehensions" diff --git a/policy/testdata/required_labels/tests.yaml b/policy/testdata/required_labels/tests.yaml index a4bf96dc2..2159b1b24 100644 --- a/policy/testdata/required_labels/tests.yaml +++ b/policy/testdata/required_labels/tests.yaml @@ -29,7 +29,8 @@ section: env: prod experiment: "group b" release: "v0.1.0" - output: "optional.none()" + output: + expr: "optional.none()" - name: "missing" tests: - name: "env" @@ -44,8 +45,8 @@ section: labels: experiment: "group b" release: "v0.1.0" - output: > - "missing one or more required labels: [\"env\"]" + output: + value: "missing one or more required labels: [\"env\"]" - name: "experiment" input: spec: @@ -58,8 +59,8 @@ section: labels: env: staging release: "v0.1.0" - output: > - "missing one or more required labels: [\"experiment\"]" + output: + 
value: "missing one or more required labels: [\"experiment\"]" - name: "invalid" tests: - name: "env" @@ -75,5 +76,5 @@ section: env: staging experiment: "group b" release: "v0.1.0" - output: > - "invalid values provided on one or more labels: [\"env\"]" + output: + value: "invalid values provided on one or more labels: [\"env\"]" diff --git a/policy/testdata/restricted_destinations/base_config.yaml b/policy/testdata/restricted_destinations/base_config.yaml index 2aae385ca..615a8b915 100644 --- a/policy/testdata/restricted_destinations/base_config.yaml +++ b/policy/testdata/restricted_destinations/base_config.yaml @@ -14,26 +14,26 @@ name: "labels" extensions: -- name: "lists" -- name: "sets" + - name: "lists" + - name: "sets" variables: -- name: "destination.ip" - type_name: "string" -- name: "origin.ip" - type_name: "string" -- name: "spec.restricted_destinations" - type_name: "list" - params: - - type_name: "string" -- name: "spec.origin" - type_name: "string" -- name: "request" - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" -- name: "resource" - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" + - name: "destination.ip" + type_name: "string" + - name: "origin.ip" + type_name: "string" + - name: "spec.restricted_destinations" + type_name: "list" + params: + - type_name: "string" + - name: "spec.origin" + type_name: "string" + - name: "request" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" + - name: "resource" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" diff --git a/policy/testdata/restricted_destinations/tests.yaml b/policy/testdata/restricted_destinations/tests.yaml index 1cf59fe62..e448fb1a9 100644 --- a/policy/testdata/restricted_destinations/tests.yaml +++ b/policy/testdata/restricted_destinations/tests.yaml @@ -40,7 +40,8 @@ section: name: "/company/acme/secrets/doomsday-device" labels: location: "us" - output: "false" # false means unrestricted + 
output: + value: false # false means unrestricted - name: "nationality_allowed" input: "spec.origin": @@ -64,7 +65,8 @@ section: name: "/company/acme/secrets/doomsday-device" labels: location: "us" - output: "false" + output: + value: false - name: "invalid" tests: - name: "destination_ip_prohibited" @@ -91,7 +93,8 @@ section: name: "/company/acme/secrets/doomsday-device" labels: location: "us" - output: "true" # true means restricted + output: + value: true # true means restricted - name: "resource_nationality_prohibited" input: "spec.origin": @@ -115,4 +118,5 @@ section: name: "/company/acme/secrets/doomsday-device" labels: location: "cu" - output: "true" + output: + value: true diff --git a/policy/testdata/unnest/tests.yaml b/policy/testdata/unnest/tests.yaml index 9bed7b352..31a8770d7 100644 --- a/policy/testdata/unnest/tests.yaml +++ b/policy/testdata/unnest/tests.yaml @@ -20,31 +20,33 @@ section: input: values: expr: "[4, 6]" - output: > - "some divisible by 2" + output: + value: "some divisible by 2" - name: "false" input: values: expr: "[1, 3, 5]" - output: "optional.none()" + output: + expr: "optional.none()" - name: "empty-set" input: values: expr: "[1, 2]" - output: "optional.none()" + output: + expr: "optional.none()" - name: "divisible by 4" tests: - name: "true" input: values: expr: "[4, 7]" - output: > - "at least one divisible by 4" + output: + value: "at least one divisible by 4" - name: "power of 6" tests: - name: "true" input: values: expr: "[6, 7]" - output: > - "at least one power of 6" + output: + value: "at least one power of 6" diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 59bd9a3dc..37b093a2a 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -9,6 +9,8 @@ package( "//interpreter:__subpackages__", "//parser:__subpackages__", "//server:__subpackages__", + "//tools:__subpackages__", + "//policy:__subpackages__", ], licenses = ["notice"], # Apache 2.0 ) @@ -18,6 +20,7 @@ go_library( srcs = [ "compare.go", "expr.go", + "suite.go", 
], importpath = "github.com/google/cel-go/test", deps = [ diff --git a/test/suite.go b/test/suite.go new file mode 100644 index 000000000..2b499e45d --- /dev/null +++ b/test/suite.go @@ -0,0 +1,60 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +// Suite is a collection of tests designed to evaluate the correctness of +// a CEL policy or a CEL expression +type Suite struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Sections []*Section `yaml:"section"` +} + +// Section is a collection of related test cases. +type Section struct { + Name string `yaml:"name"` + Tests []*Case `yaml:"tests"` +} + +// Case is a test case to validate a CEL policy or expression. The test case +// encompasses evaluation of the compiled expression using the provided input +// bindings and asserting the result against the expected result. +type Case struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Input map[string]*InputValue `yaml:"input,omitempty"` + *InputContext `yaml:",inline,omitempty"` + Output *Output `yaml:"output"` +} + +// InputContext represents an optional context expression. +type InputContext struct { + ContextExpr string `yaml:"context_expr"` +} + +// InputValue represents an input value for a binding which can be either a simple literal value or +// an expression. 
+type InputValue struct { + Value any `yaml:"value"` + Expr string `yaml:"expr"` +} + +// Output represents the expected result of a test case. +type Output struct { + Value any `yaml:"value"` + Expr string `yaml:"expr"` + ErrorSet []string `yaml:"error_set"` + UnknownSet []int64 `yaml:"unknown_set"` +} diff --git a/tools/celtest/BUILD.bazel b/tools/celtest/BUILD.bazel new file mode 100644 index 000000000..295f4e755 --- /dev/null +++ b/tools/celtest/BUILD.bazel @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_library( + name = "go_default_library", + srcs = [ + "test_runner.go", + ], + importpath = "github.com/google/cel-go/tools/celtest", + deps = [ + "//cel:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//interpreter:go_default_library", + "//test:go_default_library", + "//tools/compiler:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@dev_cel_expr//:expr", + "@dev_cel_expr//conformance/test:go_default_library", + "@in_gopkg_yaml_v3//:go_default_library", + "@io_bazel_rules_go//go/runfiles", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//reflect/protodesc:go_default_library", + "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", + "@org_golang_google_protobuf//reflect/protoregistry:go_default_library", + "@org_golang_google_protobuf//testing/protocmp:go_default_library", + "@org_golang_google_protobuf//types/descriptorpb:go_default_library", + "@org_golang_google_protobuf//types/dynamicpb:go_default_library", + ], +) + +go_test( + name = "go_default_test", + size = "small", + srcs = [ + "test_runner_test.go", + ], + data = [ + ":testdata", + "//policy:testdata", + ], + embed = [":go_default_library"], + deps = [ + "//cel:go_default_library", + "//common/decls:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//policy:go_default_library", + "//tools/compiler:go_default_library", + "@in_gopkg_yaml_v3//:go_default_library", + ] +) + +filegroup( + name = "testdata", + srcs = glob(["testdata/**"]), +) \ No newline at end of file diff --git 
a/tools/celtest/test_runner.go b/tools/celtest/test_runner.go new file mode 100644 index 000000000..39f29dea3 --- /dev/null +++ b/tools/celtest/test_runner.go @@ -0,0 +1,707 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package celtest provides functions for testing CEL policies and expressions. +package celtest + +import ( + "flag" + "fmt" + "os" + "reflect" + "strings" + "testing" + + "gopkg.in/yaml.v3" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/interpreter" + "github.com/google/cel-go/test" + "github.com/google/cel-go/tools/compiler" + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/testing/protocmp" + + celpb "cel.dev/expr" + conformancepb "cel.dev/expr/conformance/test" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + descpb "google.golang.org/protobuf/types/descriptorpb" + dynamicpb "google.golang.org/protobuf/types/dynamicpb" +) + +var ( + celExpression string + testSuitePath string + fileDescriptorSetPath string + configPath string + baseConfigPath string +) + +func init() { + flag.StringVar(&testSuitePath, "test_suite_path", "", "path to a test suite") + 
flag.StringVar(&fileDescriptorSetPath, "file_descriptor_set", "", "path to a file descriptor set") + flag.StringVar(&configPath, "config_path", "", "path to a config file") + flag.StringVar(&baseConfigPath, "base_config_path", "", "path to a base config file") + flag.StringVar(&celExpression, "cel_expr", "", "CEL expression to test") + flag.Parse() +} + +// TestRunnerOption is used to configure the following attributes of the Test Runner: +// - set the Compiler +// - add Input Expressions +// - set the test suite file path +// - set the test suite parser based on the file format: YAML or Textproto +type TestRunnerOption func(*TestRunner) (*TestRunner, error) + +// TriggerTests triggers tests for a CEL policy, expression or checked expression +// with the provided set of options. The options can be used to: +// - configure the Compiler used for parsing and compiling the expression +// - configure the Test Runner used for parsing and executing the tests +func TriggerTests(t *testing.T, testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) { + testRunnerOptions := testRunnerOptions(testRunnerOpts, testCompilerOpts...) + tr, err := NewTestRunner(testRunnerOptions...) + if err != nil { + t.Fatalf("error creating test runner: %v", err) + } + programs, err := tr.Programs(t) + if err != nil { + t.Fatalf("error creating programs: %v", err) + } + tests, err := tr.Tests(t) + if err != nil { + t.Fatalf("error creating tests: %v", err) + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := tr.ExecuteTest(t, programs, test) + if err != nil { + t.Fatalf("error executing test: %v", err) + } + }) + } +} + +func testRunnerOptions(testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) []TestRunnerOption { + compilerOpt := testRunnerCompilerFromFlags(testCompilerOpts...) 
+ testSuiteParserOpt := DefaultTestSuiteParser(testSuitePath) + fileDescriptorSetOpt := AddFileDescriptorSet(fileDescriptorSetPath) + testRunnerExprOpt := testRunnerExpressionsFromFlags() + return append([]TestRunnerOption{compilerOpt, testSuiteParserOpt, fileDescriptorSetOpt, testRunnerExprOpt}, testRunnerOpts...) +} + +func testRunnerCompilerFromFlags(testCompilerOpts ...any) TestRunnerOption { + var opts []any + if fileDescriptorSetPath != "" { + opts = append(opts, compiler.TypeDescriptorSetFile(fileDescriptorSetPath)) + } + if baseConfigPath != "" { + opts = append(opts, compiler.EnvironmentFile(baseConfigPath)) + } + if configPath != "" { + opts = append(opts, compiler.EnvironmentFile(configPath)) + } + opts = append(opts, testCompilerOpts...) + return func(tr *TestRunner) (*TestRunner, error) { + c, err := compiler.NewCompiler(opts...) + if err != nil { + return nil, err + } + tr.Compiler = c + return tr, nil + } +} + +func testRunnerExpressionsFromFlags() TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + if celExpression != "" { + tr.Expressions = append(tr.Expressions, &compiler.CompiledExpression{Path: celExpression}) + tr.Expressions = append(tr.Expressions, &compiler.FileExpression{Path: celExpression}) + tr.Expressions = append(tr.Expressions, &compiler.RawExpression{Value: celExpression}) + } + return tr, nil + } +} + +// TestSuiteParser is an interface for parsing a test suite: +// - ParseTextproto: Returns a cel.spec.expr.conformance.test.TestSuite message. +// - ParseYAML: Returns a test.Suite object. +// In case the test suite is serialized in a Textproto/YAML file, the path of the file is passed as +// an argument to the parse method. +type TestSuiteParser interface { + ParseTextproto(string) (*conformancepb.TestSuite, error) + ParseYAML(string) (*test.Suite, error) +} + +type tsParser struct { + TestSuiteParser +} + +// ParseTextproto parses a test suite file in Textproto format. 
+func (p *tsParser) ParseTextproto(path string) (*conformancepb.TestSuite, error) { + if path == "" { + return nil, nil + } + if fileFormat := compiler.InferFileFormat(path); fileFormat != compiler.TextProto { + return nil, fmt.Errorf("invalid file extension wanted: .textproto: found %v", fileFormat) + } + testSuite := &conformancepb.TestSuite{} + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("os.ReadFile(%q) failed: %v", path, err) + } + err = prototext.Unmarshal(data, testSuite) + return testSuite, err +} + +// ParseYAML parses a test suite file in YAML format. +func (p *tsParser) ParseYAML(path string) (*test.Suite, error) { + if path == "" { + return nil, nil + } + if fileFormat := compiler.InferFileFormat(path); fileFormat != compiler.TextYAML { + return nil, fmt.Errorf("invalid file extension wanted: .yaml: found %v", fileFormat) + } + testSuiteBytes, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("os.ReadFile(%q) failed: %v", path, err) + } + testSuite := &test.Suite{} + err = yaml.Unmarshal(testSuiteBytes, testSuite) + return testSuite, err +} + +// DefaultTestSuiteParser returns a TestRunnerOption which configures the test runner with a test suite parser. +func DefaultTestSuiteParser(path string) TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + if path == "" { + return tr, nil + } + tr.TestSuiteFilePath = path + tr.testSuiteParser = &tsParser{} + return tr, nil + } +} + +// TestRunner provides a structure to hold the different components required to execute tests for +// a list of Input Expressions. The TestRunner can be configured with the following options: +// - Compiler: The compiler used for parsing and compiling the input expressions. +// - Input Expressions: The list of input expressions to be tested. +// - Test Suite File Path: The path to the test suite file. +// - File Descriptor Set Path: The path to the file descriptor set file. 
+// - test Suite Parser: A parser for a test suite file serialized in Textproto/YAML format. +// +// The TestRunner provides the following methods: +// - Programs: Creates a list of CEL programs from the input expressions. +// - Tests: Creates a list of tests from the test suite file. +// - ExecuteTest: Executes a single +type TestRunner struct { + compiler.Compiler + Expressions []compiler.InputExpression + TestSuiteFilePath string + FileDescriptorSetPath string + testSuiteParser TestSuiteParser +} + +// Test represents a single test case to be executed. It encompasses the following: +// - name: The name of the test case. +// - input: The input to be used for evaluating the CEL expression. +// - resultMatcher: A function that takes in the result of evaluating the CEL expression and +// returns a TestResult. +type Test struct { + name string + input interpreter.Activation + resultMatcher func(ref.Val, error) TestResult +} + +// NewTest creates a new Test with the provided name, input and result matcher. +func NewTest(name string, input interpreter.Activation, resultMatcher func(ref.Val, error) TestResult) *Test { + return &Test{ + name: name, + input: input, + resultMatcher: resultMatcher, + } +} + +// TestResult represents the result of a test case execution. It contains the validation result +// along with the expected result and any errors encountered during the execution. +// - Success: Whether the result matcher condition validating the test case was satisfied. +// - Wanted: The expected result of the test case. +// - Error: Any error encountered during the execution. +type TestResult struct { + Success bool + Wanted string + Error error +} + +// NewTestRunner creates a Test Runner with the provided options. 
+// The options can be used to: +// - configure the Compiler used for parsing and compiling the input expressions +// - configure the Test Runner used for parsing and executing the tests +func NewTestRunner(opts ...TestRunnerOption) (*TestRunner, error) { + tr := &TestRunner{} + var err error + for _, opt := range opts { + tr, err = opt(tr) + if err != nil { + return nil, err + } + } + return tr, nil +} + +// AddFileDescriptorSet creates a Test Runner Option which adds a file descriptor set to the test +// runner. The file descriptor set is used to register proto messages in the global proto registry. +func AddFileDescriptorSet(path string) TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + if path != "" { + tr.FileDescriptorSetPath = path + } + return tr, nil + } +} + +func registerMessages(path string) error { + if path == "" { + return nil + } + fds, err := fileDescriptorSet(path) + if err != nil { + return err + } + for _, file := range fds.GetFile() { + reflectFD, err := protodesc.NewFile(file, protoregistry.GlobalFiles) + if err != nil { + return fmt.Errorf("protodesc.NewFile(%q) failed: %v", file.GetName(), err) + } + if _, err := protoregistry.GlobalFiles.FindFileByPath(reflectFD.Path()); err == nil { + continue + } + err = protoregistry.GlobalFiles.RegisterFile(reflectFD) + if err != nil { + return fmt.Errorf("protoregistry.GlobalFiles.RegisterFile() failed: %v", err) + } + for i := 0; i < reflectFD.Messages().Len(); i++ { + msg := reflectFD.Messages().Get(i) + msgType := dynamicpb.NewMessageType(msg) + err = protoregistry.GlobalTypes.RegisterMessage(msgType) + if err != nil && !strings.Contains(err.Error(), "already registered") { + return fmt.Errorf("protoregistry.GlobalTypes.RegisterMessage(%q) failed: %v", msgType, err) + } + } + } + return nil +} + +func fileDescriptorSet(path string) (*descpb.FileDescriptorSet, error) { + bytes, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read file descriptor 
set file %q: %v", fileDescriptorSetPath, err) + } + fds := &descpb.FileDescriptorSet{} + if err := proto.Unmarshal(bytes, fds); err != nil { + return nil, fmt.Errorf("failed to unmarshal file descriptor set file %q: %v", fileDescriptorSetPath, err) + } + return fds, nil +} + +// Programs creates a list of CEL programs from the input expressions configured in the test runner +// using the provided program options. +func (tr *TestRunner) Programs(t *testing.T, opts ...cel.ProgramOption) ([]cel.Program, error) { + t.Helper() + if tr.Compiler == nil { + return nil, fmt.Errorf("compiler is not set") + } + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + var programs []cel.Program + for _, expr := range tr.Expressions { + // TODO: propagate metadata map along with the program instance as a struct. + ast, _, err := expr.CreateAST(tr.Compiler) + if err != nil { + if strings.Contains(err.Error(), "invalid file extension") || + strings.Contains(err.Error(), "invalid raw expression") { + continue + } + return nil, err + } + prg, err := e.Program(ast, opts...) + if err != nil { + return nil, err + } + programs = append(programs, prg) + } + return programs, nil +} + +// Tests creates a list of tests from the test suite file and test suite parser configured in the +// test runner. 
+func (tr *TestRunner) Tests(t *testing.T) ([]*Test, error) { + if tr.Compiler == nil { + return nil, fmt.Errorf("compiler is not set") + } + if tr.testSuiteParser == nil { + return nil, fmt.Errorf("test suite parser is not set") + } + if testSuite, err := tr.testSuiteParser.ParseYAML(tr.TestSuiteFilePath); err != nil && + !strings.Contains(err.Error(), "invalid file extension") { + return nil, fmt.Errorf("tr.testSuiteParser.ParseYAML(%q) failed: %v", tr.TestSuiteFilePath, err) + } else if testSuite != nil { + return tr.createTestsFromYAML(t, testSuite) + } + err := registerMessages(tr.FileDescriptorSetPath) + if err != nil { + return nil, fmt.Errorf("registerMessages(%q) failed: %v", tr.FileDescriptorSetPath, err) + } + if testSuite, err := tr.testSuiteParser.ParseTextproto(tr.TestSuiteFilePath); err != nil && + !strings.Contains(err.Error(), "invalid file extension") { + return nil, fmt.Errorf("tr.testSuiteParser.ParseTextproto(%q) failed: %v", tr.TestSuiteFilePath, err) + } else if testSuite != nil { + return tr.createTestsFromTextproto(t, testSuite) + } + return nil, nil +} + +func (tr *TestRunner) createTestsFromTextproto(t *testing.T, testSuite *conformancepb.TestSuite) ([]*Test, error) { + var tests []*Test + for _, section := range testSuite.GetSections() { + sectionName := section.GetName() + for _, testCase := range section.GetTests() { + testName := fmt.Sprintf("%s/%s", sectionName, testCase.GetName()) + testInput, err := tr.createTestInputFromPB(t, testCase) + if err != nil { + return nil, err + } + testResultMatcher, err := tr.createResultMatcherFromPB(t, testCase) + if err != nil { + return nil, err + } + tests = append(tests, NewTest(testName, testInput, testResultMatcher)) + } + } + return tests, nil +} + +func (tr *TestRunner) createTestInputFromPB(t *testing.T, testCase *conformancepb.TestCase) (interpreter.Activation, error) { + t.Helper() + input := map[string]any{} + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + if 
testCase.GetInputContext() != nil { + if len(testCase.GetInput()) != 0 { + return nil, fmt.Errorf("only one of input and input_context can be provided at a time") + } + switch testInput := testCase.GetInputContext().GetInputContextKind().(type) { + case *conformancepb.InputContext_ContextExpr: + refVal, err := tr.eval(testInput.ContextExpr) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", testInput.ContextExpr, err) + } + ctx, err := refVal.ConvertToNative( + reflect.TypeOf((*proto.Message)(nil)).Elem()) + if err != nil { + return nil, fmt.Errorf("context variable is not a valid proto: %w", err) + } + return cel.ContextProtoVars(ctx.(proto.Message)) + case *conformancepb.InputContext_ContextMessage: + refVal := e.CELTypeAdapter().NativeToValue(testInput.ContextMessage) + ctx, err := refVal.ConvertToNative(reflect.TypeOf((*proto.Message)(nil)).Elem()) + if err != nil { + return nil, fmt.Errorf("context variable is not a valid proto: %w", err) + } + return cel.ContextProtoVars(ctx.(proto.Message)) + } + } + for k, v := range testCase.GetInput() { + switch v.GetKind().(type) { + case *conformancepb.InputValue_Value: + input[k], err = cel.ProtoAsValue(e.CELTypeAdapter(), v.GetValue()) + if err != nil { + return nil, fmt.Errorf("cel.ProtoAsValue(%q) failed: %w", v, err) + } + case *conformancepb.InputValue_Expr: + input[k], err = tr.eval(v.GetExpr()) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", v.GetExpr(), err) + } + } + } + return interpreter.NewActivation(input) +} + +func (tr *TestRunner) createResultMatcherFromPB(t *testing.T, testCase *conformancepb.TestCase) (func(ref.Val, error) TestResult, error) { + t.Helper() + if testCase.GetOutput() == nil { + return nil, fmt.Errorf("expected output is nil") + } + successResult := TestResult{Success: true} + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + switch testOutput := testCase.GetOutput().GetResultKind().(type) { + case *conformancepb.TestOutput_ResultValue: + 
return func(val ref.Val, err error) TestResult { + want := e.CELTypeAdapter().NativeToValue(testOutput.ResultValue) + if err != nil { + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: err} + } + outputVal, err := refValueToExprValue(val) + if err != nil { + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: fmt.Errorf("refValueToExprValue(%q) failed: %v", val, err)} + } + testResultVal, err := canonicalValueToV1Alpha1(testOutput.ResultValue) + if err != nil { + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: fmt.Errorf("canonicalValueToV1Alpha1(%q) failed: %v", testOutput.ResultValue, err)} + } + testVal := &exprpb.ExprValue{ + Kind: &exprpb.ExprValue_Value{Value: testResultVal}} + + if diff := cmp.Diff(testVal, outputVal, protocmp.Transform(), + protocmp.SortRepeatedFields(&exprpb.MapValue{}, "entries")); diff != "" { + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: fmt.Errorf("mismatched test output with diff (-want +got):\n%s", diff)} + } + return successResult + }, nil + case *conformancepb.TestOutput_ResultExpr: + return func(val ref.Val, err error) TestResult { + if err != nil { + return TestResult{Success: false, Error: err} + } + testOut, err := tr.eval(testOutput.ResultExpr) + if err != nil { + return TestResult{Success: false, Error: fmt.Errorf("eval(%q) failed: %v", testOutput.ResultExpr, err)} + } + if optOut, ok := val.(*types.Optional); ok { + if optOut.Equal(types.OptionalNone) == types.True { + if testOut.Equal(types.OptionalNone) != types.True { + return TestResult{Success: false, Wanted: fmt.Sprintf("optional value %v", testOut), Error: fmt.Errorf("policy eval got %v", val)} + } + } else if testOut.Equal(optOut.GetValue()) != types.True { + return TestResult{Success: false, Wanted: fmt.Sprintf("optional value %v", testOut), Error: fmt.Errorf("policy eval got %v", val)} + } + } else if 
val.Equal(testOut) != types.True { + return TestResult{Success: false, Wanted: fmt.Sprintf("optional value %v", testOut), Error: fmt.Errorf("policy eval got %v", val)} + } + return successResult + }, nil + case *conformancepb.TestOutput_EvalError: + return func(val ref.Val, err error) TestResult { + failureResult := TestResult{Success: false, Wanted: fmt.Sprintf("error %v", testOutput.EvalError)} + if err == nil { + return failureResult + } + // Compare the evaluated error with the expected error message only. + for _, want := range testOutput.EvalError.GetErrors() { + if strings.Contains(err.Error(), want.GetMessage()) { + return successResult + } + } + return failureResult + }, nil + case *conformancepb.TestOutput_Unknown: + // TODO: to implement + } + return nil, nil +} + +func refValueToExprValue(refVal ref.Val) (*exprpb.ExprValue, error) { + if types.IsUnknown(refVal) { + return &exprpb.ExprValue{ + Kind: &exprpb.ExprValue_Unknown{ + Unknown: &exprpb.UnknownSet{ + Exprs: refVal.Value().([]int64), + }, + }}, nil + } + v, err := cel.RefValueToValue(refVal) + if err != nil { + return nil, err + } + return &exprpb.ExprValue{ + Kind: &exprpb.ExprValue_Value{Value: v}}, nil +} + +func canonicalValueToV1Alpha1(val *celpb.Value) (*exprpb.Value, error) { + var v1val exprpb.Value + b, err := prototext.Marshal(val) + if err != nil { + return nil, err + } + if err := prototext.Unmarshal(b, &v1val); err != nil { + return nil, err + } + return &v1val, nil +} + +func (tr *TestRunner) eval(expr string) (ref.Val, error) { + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + e, err = e.Extend(cel.OptionalTypes()) + if err != nil { + return nil, fmt.Errorf("e.Extend() failed: %v", err) + } + ast, iss := e.Compile(expr) + if iss.Err() != nil { + return nil, fmt.Errorf("e.Compile(%q) failed: %v", expr, iss.Err()) + } + prg, err := e.Program(ast) + if err != nil { + return nil, fmt.Errorf("e.Program(%q) failed: %v", expr, err) + } + out, _, err := 
prg.Eval(cel.NoVars()) + if err != nil { + return nil, fmt.Errorf("prg.Eval(%q) failed: %v", expr, err) + } + return out, nil +} + +func (tr *TestRunner) createTestsFromYAML(t *testing.T, testSuite *test.Suite) ([]*Test, error) { + var tests []*Test + for _, section := range testSuite.Sections { + for _, testCase := range section.Tests { + testName := fmt.Sprintf("%s/%s", section.Name, testCase.Name) + testInput, err := tr.createTestInput(t, testCase) + if err != nil { + return nil, err + } + testResultMatcher, err := tr.createResultMatcher(t, testCase.Output) + if err != nil { + return nil, err + } + tests = append(tests, NewTest(testName, testInput, testResultMatcher)) + } + } + return tests, nil +} + +func (tr *TestRunner) createTestInput(t *testing.T, testCase *test.Case) (interpreter.Activation, error) { + t.Helper() + if testCase.InputContext != nil && testCase.InputContext.ContextExpr != "" { + if len(testCase.Input) != 0 { + return nil, fmt.Errorf("only one of input and input_context can be provided at a time") + } + contextExpr := testCase.InputContext.ContextExpr + out, err := tr.eval(contextExpr) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", contextExpr, err) + } + ctx, err := out.ConvertToNative(reflect.TypeOf((*proto.Message)(nil)).Elem()) + if err != nil { + return nil, fmt.Errorf("context variable is not a valid proto: %w", err) + } + return cel.ContextProtoVars(ctx.(proto.Message)) + } + input := map[string]any{} + for k, v := range testCase.Input { + if v.Expr != "" { + val, err := tr.eval(v.Expr) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", v.Expr, err) + } + input[k] = val + continue + } + input[k] = v.Value + } + return interpreter.NewActivation(input) +} + +func (tr *TestRunner) createResultMatcher(t *testing.T, testOutput *test.Output) (func(ref.Val, error) TestResult, error) { + t.Helper() + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + successResult := TestResult{Success: true} + 
if testOutput.Value != nil { + want := e.CELTypeAdapter().NativeToValue(testOutput.Value) + return func(out ref.Val, err error) TestResult { + if err == nil { + if out.Equal(want) == types.True { + return successResult + } + if optOut, ok := out.(*types.Optional); ok { + if optOut.HasValue() && optOut.GetValue().Equal(want) == types.True { + return successResult + } + } + } + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: err} + }, nil + } + if testOutput.Expr != "" { + want, err := tr.eval(testOutput.Expr) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", testOutput.Expr, err) + } + return func(out ref.Val, err error) TestResult { + if err == nil { + if out.Equal(want) == types.True { + return successResult + } + if optOut, ok := out.(*types.Optional); ok { + if optOut.HasValue() && optOut.GetValue().Equal(want) == types.True { + return successResult + } + } + } + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: err} + }, nil + } + if testOutput.ErrorSet != nil { + return func(out ref.Val, err error) TestResult { + failureResult := TestResult{Success: false, Wanted: fmt.Sprintf("error %v", testOutput.ErrorSet)} + if err == nil { + return failureResult + } + for _, want := range testOutput.ErrorSet { + if strings.Contains(err.Error(), want) { + return successResult + } + } + return failureResult + }, nil + } + if testOutput.UnknownSet != nil { + // TODO: to implement + } + return nil, nil +} + +// ExecuteTest executes the test case against the provided list of programs and returns an error if +// the test fails. 
+func (tr *TestRunner) ExecuteTest(t *testing.T, programs []cel.Program, test *Test) error { + t.Helper() + if tr.Compiler == nil { + return fmt.Errorf("compiler is not set") + } + for _, program := range programs { + out, _, err := program.Eval(test.input) + if testResult := test.resultMatcher(out, err); !testResult.Success { + return fmt.Errorf("test: %s \n wanted: %v \n failed: %v", test.name, testResult.Wanted, testResult.Error) + } + } + return nil +} diff --git a/tools/celtest/test_runner_test.go b/tools/celtest/test_runner_test.go new file mode 100644 index 000000000..3530c2c7c --- /dev/null +++ b/tools/celtest/test_runner_test.go @@ -0,0 +1,196 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package celtest provides functions for testing CEL policies and expressions. 
+package celtest + +import ( + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/policy" + "github.com/google/cel-go/tools/compiler" + "gopkg.in/yaml.v3" +) + +type testCase struct { + name string + celExpression string + testSuitePath string + fileDescriptorSetPath string + configPath string + opts []any +} + +func setupTests() []*testCase { + testCases := []*testCase{ + { + name: "policy test with custom policy parser", + celExpression: "../../policy/testdata/k8s/policy.yaml", + testSuitePath: "../../policy/testdata/k8s/tests.yaml", + configPath: "../../policy/testdata/k8s/config.yaml", + opts: []any{k8sParserOpts()}, + }, + { + name: "policy test with function binding", + celExpression: "../../policy/testdata/restricted_destinations/policy.yaml", + testSuitePath: "../../policy/testdata/restricted_destinations/tests.yaml", + configPath: "../../policy/testdata/restricted_destinations/config.yaml", + opts: []any{locationCodeEnvOption()}, + }, + { + name: "policy test with custom policy metadata", + celExpression: "testdata/custom_policy.celpolicy", + testSuitePath: "testdata/custom_policy_tests.yaml", + opts: []any{customPolicyParserOption(), compiler.PolicyMetadataEnvOption(ParsePolicyVariables)}, + }, + { + name: "raw expression file test", + celExpression: "testdata/raw_expr.cel", + testSuitePath: "testdata/raw_expr_tests", + configPath: "testdata/config.yaml", + opts: []any{fnEnvOption()}, + }, + { + name: "raw expression test", + celExpression: "'i + fn(j) == 42'", + testSuitePath: "testdata/raw_expr_tests", + configPath: "testdata/config.yaml", + opts: []any{fnEnvOption()}, + }, + } + return testCases +} + +func locationCodeEnvOption() cel.EnvOption { + return cel.Function("locationCode", + cel.Overload("locationCode_string", []*cel.Type{cel.StringType}, cel.StringType, + cel.UnaryBinding(locationCode))) +} + 
+func locationCode(ip ref.Val) ref.Val { + switch ip.(types.String) { + case "10.0.0.1": + return types.String("us") + case "10.0.0.2": + return types.String("de") + default: + return types.String("ir") + } +} + +func k8sParserOpts() policy.ParserOption { + return func(p *policy.Parser) (*policy.Parser, error) { + p.TagVisitor = policy.K8sTestTagHandler() + return p, nil + } +} + +// TestTriggerTestsCustomPolicy tests the TriggerTestsFromCompiler function for a custom policy +// by providing test runner and compiler options without setting the flag variables. +func TestTriggerTestsWithRunnerOptions(t *testing.T) { + t.Run("test trigger tests custom policy", func(t *testing.T) { + envOpt := compiler.EnvironmentFile("../../policy/testdata/k8s/config.yaml") + testSuiteParser := DefaultTestSuiteParser("../../policy/testdata/k8s/tests.yaml") + testCELPolicy := TestRunnerOption(func(tr *TestRunner) (*TestRunner, error) { + tr.Expressions = append(tr.Expressions, &compiler.FileExpression{ + Path: "../../policy/testdata/k8s/policy.yaml", + }) + return tr, nil + }) + c, err := compiler.NewCompiler(envOpt, k8sParserOpts()) + if err != nil { + t.Fatalf("compiler.NewCompiler() failed: %v", err) + } + compilerOpt := TestRunnerOption(func(tr *TestRunner) (*TestRunner, error) { + tr.Compiler = c + return tr, nil + }) + opts := []TestRunnerOption{compilerOpt, testSuiteParser, testCELPolicy} + TriggerTests(t, opts) + }) +} + +func customPolicyParserOption() policy.ParserOption { + return func(p *policy.Parser) (*policy.Parser, error) { + p.TagVisitor = customTagHandler{TagVisitor: policy.DefaultTagVisitor()} + return p, nil + } +} +func ParsePolicyVariables(metadata map[string]any) cel.EnvOption { + var variables []*decls.VariableDecl + for n, t := range metadata { + variables = append(variables, decls.NewVariable(n, parseCustomPolicyVariableType(t.(string)))) + } + return cel.VariableDecls(variables...) 
+} + +func parseCustomPolicyVariableType(t string) *types.Type { + switch t { + case "int": + return types.IntType + case "string": + return types.StringType + default: + return types.UnknownType + } +} + +type variableType struct { + VariableName string `yaml:"variable_name"` + VariableType string `yaml:"variable_type"` +} + +type customTagHandler struct { + policy.TagVisitor +} + +func (customTagHandler) PolicyTag(ctx policy.ParserContext, id int64, tagName string, node *yaml.Node, p *policy.Policy) { + switch tagName { + case "variable_types": + var varList []*variableType + if err := node.Decode(&varList); err != nil { + ctx.ReportErrorAtID(id, "invalid yaml variable_types node: %v, error: %w", node, err) + return + } + for _, v := range varList { + p.SetMetadata(v.VariableName, v.VariableType) + } + default: + ctx.ReportErrorAtID(id, "unsupported policy tag: %s", tagName) + } +} + +func fnEnvOption() cel.EnvOption { + return cel.Function("fn", + cel.Overload("fn_int", []*cel.Type{cel.IntType}, cel.IntType, + cel.UnaryBinding(func(in ref.Val) ref.Val { + i := in.(types.Int) + return i / types.Int(2) + }))) +} + +// TestTriggerTests tests different scenarios of the TriggerTestsFromCompiler function. +func TestTriggerTests(t *testing.T) { + for _, tc := range setupTests() { + celExpression = tc.celExpression + testSuitePath = tc.testSuitePath + configPath = tc.configPath + fileDescriptorSetPath = tc.fileDescriptorSetPath + TriggerTests(t, nil, tc.opts...) + } +} diff --git a/tools/compiler/testdata/custom_policy_config.yaml b/tools/celtest/testdata/config.yaml similarity index 83% rename from tools/compiler/testdata/custom_policy_config.yaml rename to tools/celtest/testdata/config.yaml index 7b54a43da..62abcb23e 100644 --- a/tools/compiler/testdata/custom_policy_config.yaml +++ b/tools/celtest/testdata/config.yaml @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-name: "custom_policy_config" -extensions: - - name: "optional" - version: "latest" +name: "simple expression config" +variables: + - name: "i" + type_name: "int" + - name: "j" + type_name: "int" diff --git a/tools/compiler/testdata/custom_policy.celpolicy b/tools/celtest/testdata/custom_policy.celpolicy similarity index 97% rename from tools/compiler/testdata/custom_policy.celpolicy rename to tools/celtest/testdata/custom_policy.celpolicy index 663fcf0a7..3867b26fe 100644 --- a/tools/compiler/testdata/custom_policy.celpolicy +++ b/tools/celtest/testdata/custom_policy.celpolicy @@ -23,3 +23,4 @@ rule: - condition: | variable1 == 1 || variable2 == "known" output: "true" + - output: "false" \ No newline at end of file diff --git a/tools/celtest/testdata/custom_policy_tests.yaml b/tools/celtest/testdata/custom_policy_tests.yaml new file mode 100644 index 000000000..f2b554f83 --- /dev/null +++ b/tools/celtest/testdata/custom_policy_tests.yaml @@ -0,0 +1,42 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +description: "Custom policy tests" +section: + - name: "output true" + tests: + - name: "variable 1 match" + input: + variable1: + value: 1 + output: + value: true + - name: "variable 2 match" + input: + variable1: + value: 2 + variable2: + value: "known" + output: + value: true + - name: "output false" + tests: + - name: "variable mismatch" + input: + variable1: + value: 2 + variable2: + value: "unknown" + output: + value: false diff --git a/tools/celtest/testdata/raw_expr.cel b/tools/celtest/testdata/raw_expr.cel new file mode 100644 index 000000000..63386498f --- /dev/null +++ b/tools/celtest/testdata/raw_expr.cel @@ -0,0 +1 @@ +"'i + fn(j) == 42'" \ No newline at end of file diff --git a/tools/celtest/testdata/raw_expr_tests.yaml b/tools/celtest/testdata/raw_expr_tests.yaml new file mode 100644 index 000000000..547d4319f --- /dev/null +++ b/tools/celtest/testdata/raw_expr_tests.yaml @@ -0,0 +1,34 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +description: "simple expression tests" +section: + - name: "valid" + tests: + - name: "true" + input: + i: + value: 21 + j: + value: 42 + output: + value: true + - name: "false" + input: + i: + value: 22 + j: + value:42 + output: + value: false diff --git a/tools/compiler/BUILD.bazel b/tools/compiler/BUILD.bazel index 0c3e4080b..3b84df5ca 100644 --- a/tools/compiler/BUILD.bazel +++ b/tools/compiler/BUILD.bazel @@ -57,19 +57,16 @@ go_test( ], data = [ ":compiler_testdata", - "//policy:k8s_policy_testdata", + "//policy:testdata", ], embed = [":go_default_library"], deps = [ "//cel:go_default_library", - "//common/decls:go_default_library", "//common/env:go_default_library", - "//common/types:go_default_library", "//ext:go_default_library", "//policy:go_default_library", "@dev_cel_expr//:expr", "@dev_cel_expr//conformance:go_default_library", - "@in_gopkg_yaml_v3//:go_default_library", "@org_golang_google_protobuf//types/known/structpb:go_default_library", ], ) diff --git a/tools/compiler/compiler.go b/tools/compiler/compiler.go index c2263e02f..272df2926 100644 --- a/tools/compiler/compiler.go +++ b/tools/compiler/compiler.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "path/filepath" + "sync" "gopkg.in/yaml.v3" @@ -57,6 +58,22 @@ const ( CELPolicy ) +// ExpressionType is an enum for the type of input expression. +type ExpressionType int + +const ( + // ExpressionTypeUnspecified is used when the expression type is not specified. + ExpressionTypeUnspecified ExpressionType = iota + // CompiledExpressionFile is file containing a checked expression. + CompiledExpressionFile + // PolicyFile is a file containing a CEL policy. + PolicyFile + // ExpressionFile is a file containing a CEL expression. + ExpressionFile + // RawExpressionString is a raw CEL expression string. + RawExpressionString +) + // PolicyMetadataEnvOption represents a function which accepts a policy metadata map and returns an // environment option used to extend the CEL environment. 
// @@ -90,6 +107,7 @@ type compiler struct { policyCompilerOptions []policy.CompilerOption policyMetadataEnvOptions []PolicyMetadataEnvOption env *cel.Env + doOnce sync.Once } // NewCompiler creates a new compiler with a set of functional options. @@ -114,20 +132,29 @@ func NewCompiler(opts ...any) (Compiler, error) { return nil, fmt.Errorf("unsupported compiler option: %v", opt) } } + c.envOptions = append(c.envOptions, extensionOpt()) return c, nil } +func extensionOpt() cel.EnvOption { + return func(e *cel.Env) (*cel.Env, error) { + envConfig := &env.Config{ + Extensions: []*env.Extension{ + &env.Extension{Name: "optional", Version: "latest"}, + &env.Extension{Name: "bindings", Version: "latest"}, + }, + } + return e.Extend(cel.FromConfig(envConfig, ext.ExtensionOptionFactory)) + } +} + // CreateEnv creates a singleton CEL environment with the configured environment options. func (c *compiler) CreateEnv() (*cel.Env, error) { - if c.env != nil { - return c.env, nil - } - env, err := cel.NewCustomEnv(c.envOptions...) - if err != nil { - return nil, err - } - c.env = env - return c.env, nil + var err error + c.doOnce.Do(func() { + c.env, err = cel.NewCustomEnv(c.envOptions...) + }) + return c.env, err } // CreatePolicyParser creates a policy parser using the optionally configured parser options. @@ -165,7 +192,8 @@ func loadProtoFile(path string, format FileFormat, out protoreflect.ProtoMessage return unmarshaller(data, out) } -func inferFileFormat(path string) FileFormat { +// InferFileFormat infers the file format from the file path. 
+func InferFileFormat(path string) FileFormat { extension := filepath.Ext(path) switch extension { case ".textproto": @@ -190,7 +218,7 @@ func inferFileFormat(path string) FileFormat { // - Binarypb func EnvironmentFile(path string) cel.EnvOption { return func(e *cel.Env) (*cel.Env, error) { - format := inferFileFormat(path) + format := InferFileFormat(path) if format != TextProto && format != TextYAML && format != BinaryProto { return nil, fmt.Errorf("file extension must be one of .textproto, .yaml, .binarypb: found %v", format) } @@ -403,7 +431,7 @@ func protoDeclToFunction(decl *celpb.Decl) (*env.Function, error) { // The file must be in binary format. func TypeDescriptorSetFile(path string) cel.EnvOption { return func(e *cel.Env) (*cel.Env, error) { - format := inferFileFormat(path) + format := InferFileFormat(path) if format != BinaryProto { return nil, fmt.Errorf("type descriptor must be in binary format") } @@ -438,9 +466,9 @@ type CompiledExpression struct { // - Textproto func (c *CompiledExpression) CreateAST(_ Compiler) (*cel.Ast, map[string]any, error) { var expr exprpb.CheckedExpr - format := inferFileFormat(c.Path) + format := InferFileFormat(c.Path) if format != BinaryProto && format != TextProto { - return nil, nil, fmt.Errorf("file extension must be .binarypb or .textproto: found %v", format) + return nil, nil, fmt.Errorf("invalid file extension wanted: .binarypb or .textproto found: %v", format) } if err := loadProtoFile(c.Path, format, &expr); err != nil { return nil, nil, err @@ -466,13 +494,13 @@ func (f *FileExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, if err != nil { return nil, nil, err } - data, err := loadFile(f.Path) - if err != nil { - return nil, nil, err - } - format := inferFileFormat(f.Path) + format := InferFileFormat(f.Path) switch format { case CELString: + data, err := loadFile(f.Path) + if err != nil { + return nil, nil, err + } src := common.NewStringSource(string(data), f.Path) ast, iss := 
e.CompileSource(src) if iss.Err() != nil { @@ -480,6 +508,10 @@ func (f *FileExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, } return ast, nil, nil case CELPolicy, TextYAML: + data, err := loadFile(f.Path) + if err != nil { + return nil, nil, err + } src := policy.ByteSource(data, f.Path) parser, err := compiler.CreatePolicyParser() if err != nil { @@ -501,7 +533,7 @@ func (f *FileExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, } return ast, policyMetadata, nil default: - return nil, nil, fmt.Errorf("unsupported file format: %v", format) + return nil, nil, fmt.Errorf("invalid file extension wanted: .cel or .celpolicy or .yaml found: %v", format) } } @@ -526,6 +558,10 @@ func (r *RawExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, if err != nil { return nil, nil, err } + format := InferFileFormat(r.Value) + if format != Unspecified { + return nil, nil, fmt.Errorf("invalid raw expression found file with extension: %v", format) + } ast, iss := e.Compile(r.Value) if iss.Err() != nil { return nil, nil, fmt.Errorf("e.Compile(%q) failed: %w", r.Value, iss.Err()) diff --git a/tools/compiler/compiler_test.go b/tools/compiler/compiler_test.go index d0c9ad0be..4bbad785e 100644 --- a/tools/compiler/compiler_test.go +++ b/tools/compiler/compiler_test.go @@ -19,12 +19,9 @@ import ( "testing" "github.com/google/cel-go/cel" - "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/env" - "github.com/google/cel-go/common/types" "github.com/google/cel-go/ext" "github.com/google/cel-go/policy" - "gopkg.in/yaml.v3" celpb "cel.dev/expr" configpb "cel.dev/expr/conformance" @@ -75,7 +72,7 @@ func TestEnvironmentFileCompareTextprotoAndYAML(t *testing.T) { for i, v := range protoConfig.Variables { for j, p := range v.TypeDesc.Params { if p.TypeName == "google.protobuf.Any" && - config.Variables[i].TypeDesc.Params[j].TypeName == "dyn" { + config.Variables[i].TypeDesc.Params[j].TypeName == "dyn" { p.TypeName = "dyn" 
} } @@ -186,10 +183,6 @@ func testEnvProto() *configpb.Environment { }, }, Extensions: []*configpb.Extension{ - { - Name: "optional", - Version: "latest", - }, { Name: "lists", Version: "latest", @@ -401,76 +394,6 @@ func TestFileExpressionCustomPolicyParser(t *testing.T) { }) } -func TestFileExpressionPolicyMetadataOptions(t *testing.T) { - t.Run("test file expression policy metadata options", func(t *testing.T) { - envOpt := EnvironmentFile("testdata/custom_policy_config.yaml") - parserOpt := policy.ParserOption(func(p *policy.Parser) (*policy.Parser, error) { - p.TagVisitor = customTagHandler{TagVisitor: policy.DefaultTagVisitor()} - return p, nil - }) - policyMetadataOpt := PolicyMetadataEnvOption(ParsePolicyVariables) - compilerOpts := []any{envOpt, parserOpt, policyMetadataOpt} - compiler, err := NewCompiler(compilerOpts...) - if err != nil { - t.Fatalf("NewCompiler() failed: %v", err) - } - policyFile := &FileExpression{ - Path: "testdata/custom_policy.celpolicy", - } - ast, _, err := policyFile.CreateAST(compiler) - if err != nil { - t.Fatalf("CreateAST() failed: %v", err) - } - if ast == nil { - t.Fatalf("CreateAST() returned nil ast") - } - }) -} - -func ParsePolicyVariables(metadata map[string]any) cel.EnvOption { - variables := []*decls.VariableDecl{} - for n, t := range metadata { - variables = append(variables, decls.NewVariable(n, parseCustomPolicyVariableType(t.(string)))) - } - return cel.VariableDecls(variables...) 
-} - -func parseCustomPolicyVariableType(t string) *types.Type { - switch t { - case "int": - return types.IntType - case "string": - return types.StringType - default: - return types.UnknownType - } -} - -type variableType struct { - VariableName string `yaml:"variable_name"` - VariableType string `yaml:"variable_type"` -} - -type customTagHandler struct { - policy.TagVisitor -} - -func (customTagHandler) PolicyTag(ctx policy.ParserContext, id int64, tagName string, node *yaml.Node, p *policy.Policy) { - switch tagName { - case "variable_types": - varList := []*variableType{} - if err := node.Decode(&varList); err != nil { - ctx.ReportErrorAtID(id, "invalid yaml variable_types node: %v, error: %w", node, err) - return - } - for _, v := range varList { - p.SetMetadata(v.VariableName, v.VariableType) - } - default: - ctx.ReportErrorAtID(id, "unsupported policy tag: %s", tagName) - } -} - func TestRawExpressionCreateAst(t *testing.T) { t.Run("test raw expression create ast", func(t *testing.T) { envOpt := EnvironmentFile("testdata/config.yaml") diff --git a/tools/compiler/testdata/config.yaml b/tools/compiler/testdata/config.yaml index 929427bc0..5ba153bbc 100644 --- a/tools/compiler/testdata/config.yaml +++ b/tools/compiler/testdata/config.yaml @@ -39,8 +39,6 @@ stdlib: return: type_name: "bool" extensions: - - name: "optional" - version: "latest" - name: "lists" version: "latest" - name: "sets" diff --git a/tools/go.mod b/tools/go.mod index 392efac8f..ea3a3dbf8 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -3,9 +3,10 @@ module github.com/google/cel-go/tools go 1.23.0 require ( - cel.dev/expr v0.22.1 + cel.dev/expr v0.23.1 github.com/google/cel-go v0.22.0 github.com/google/cel-go/policy v0.0.0-20250311174852-f5ea07b389a1 + github.com/google/go-cmp v0.6.0 google.golang.org/genproto/googleapis/api v0.0.0-20250311190419-81fb87f6b8bf google.golang.org/protobuf v1.36.5 gopkg.in/yaml.v3 v3.0.1 diff --git a/tools/go.sum b/tools/go.sum index b34becfc8..8e2a54ac3 100644 
--- a/tools/go.sum +++ b/tools/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= -cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= +cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/vendor/cel.dev/expr/MODULE.bazel b/vendor/cel.dev/expr/MODULE.bazel index c0a631316..85ac9ff61 100644 --- a/vendor/cel.dev/expr/MODULE.bazel +++ b/vendor/cel.dev/expr/MODULE.bazel @@ -8,7 +8,7 @@ bazel_dep( ) bazel_dep( name = "gazelle", - version = "0.36.0", + version = "0.39.1", repo_name = "bazel_gazelle", ) bazel_dep( @@ -35,11 +35,11 @@ bazel_dep( ) bazel_dep( name = "rules_cc", - version = "0.0.9", + version = "0.0.17", ) bazel_dep( name = "rules_go", - version = "0.50.1", + version = "0.53.0", repo_name = "io_bazel_rules_go", ) bazel_dep( @@ -48,7 +48,7 @@ bazel_dep( ) bazel_dep( name = "rules_proto", - version = "6.0.0", + version = "7.0.2", ) bazel_dep( name = "rules_python", @@ -63,7 +63,7 @@ python.toolchain( ) go_sdk = use_extension("@io_bazel_rules_go//go:extensions.bzl", "go_sdk") -go_sdk.download(version = "1.21.1") +go_sdk.download(version = "1.22.0") go_deps = use_extension("@bazel_gazelle//:extensions.bzl", "go_deps") go_deps.from_file(go_mod = "//:go.mod") diff --git a/vendor/modules.txt b/vendor/modules.txt index dfdf1bd13..a34dce8d0 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,5 +1,5 @@ -# cel.dev/expr v0.22.1 -## explicit; go 1.21.1 +# cel.dev/expr v0.23.1 +## explicit; go 1.22.0 cel.dev/expr # github.com/antlr4-go/antlr/v4 v4.13.0 ## explicit; go 1.20 From 5de96a55872fc08c6d24da6387af331dbb0cb22e Mon Sep 17 
00:00:00 2001 From: l46kok Date: Mon, 21 Apr 2025 13:04:21 -0700 Subject: [PATCH 37/46] Add two var comprehensions to repl (#1167) --- repl/evaluator.go | 17 +++++++++-------- repl/evaluator_test.go | 5 +++-- repl/go.mod | 2 +- repl/main/README.md | 11 ++++++----- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/repl/evaluator.go b/repl/evaluator.go index 6befc46a2..0b4cd1a69 100644 --- a/repl/evaluator.go +++ b/repl/evaluator.go @@ -43,14 +43,15 @@ import ( var ( extensionMap = map[string]cel.EnvOption{ - "optional": cel.OptionalTypes(), - "bindings": ext.Bindings(), - "strings": ext.Strings(), - "protos": ext.Protos(), - "math": ext.Math(), - "encoders": ext.Encoders(), - "sets": ext.Sets(), - "lists": ext.Lists(), + "optional": cel.OptionalTypes(), + "bindings": ext.Bindings(), + "strings": ext.Strings(), + "protos": ext.Protos(), + "math": ext.Math(), + "encoders": ext.Encoders(), + "sets": ext.Sets(), + "lists": ext.Lists(), + "two_var_comprehensions": ext.TwoVarComprehensions(), } ) diff --git a/repl/evaluator_test.go b/repl/evaluator_test.go index 5a82c2fa5..ab4717c04 100644 --- a/repl/evaluator_test.go +++ b/repl/evaluator_test.go @@ -709,7 +709,8 @@ func TestProcess(t *testing.T) { expr: "'test'.substring(2) == 'st' && " + "proto.getExt(google.expr.proto2.test.ExampleType{}, google.expr.proto2.test.int32_ext) == 0 && " + "math.greatest(1,2) == 2 && " + - "base64.encode(b'hello') == 'aGVsbG8='", + "base64.encode(b'hello') == 'aGVsbG8=' && " + + "{'key': 1}.exists(k, v, k == 'key' && v == 1)", }, }, wantText: "true : bool", @@ -1040,7 +1041,7 @@ func TestProcessOptionError(t *testing.T) { "'bogus'", }, }, - errorMsg: "extension: Unknown option: 'bogus'. Available options are: ['all', 'bindings', 'encoders', 'lists', 'math', 'optional', 'protos', 'sets', 'strings']", + errorMsg: "extension: Unknown option: 'bogus'. 
Available options are: ['all', 'bindings', 'encoders', 'lists', 'math', 'optional', 'protos', 'sets', 'strings', 'two_var_comprehensions']", }, } diff --git a/repl/go.mod b/repl/go.mod index 2ac63e68a..969eaec49 100644 --- a/repl/go.mod +++ b/repl/go.mod @@ -3,7 +3,7 @@ module github.com/google/cel-go/repl go 1.22.0 require ( - cel.dev/expr v0.22.1 + cel.dev/expr v0.23.1 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/chzyer/readline v1.5.1 github.com/google/cel-go v0.0.0-00010101000000-000000000000 diff --git a/repl/main/README.md b/repl/main/README.md index 6ad875d27..a2f9e47f5 100644 --- a/repl/main/README.md +++ b/repl/main/README.md @@ -50,7 +50,7 @@ into a protocol buffer text format representing the type-checked AST. `%compile ` -Example: +Example: ``` > %compile 3u @@ -158,14 +158,15 @@ may take string arguments. `--container ` sets the expression container for name resolution. -`--extension ` enables CEL extensions. Valid options are: -`strings`, `protos`, `math`, `encoders`, `optional`, `bindings`, and `all`. +`--extension ` enables CEL extensions. Valid options are: +`strings`, `protos`, `math`, `encoders`, `optional`, `bindings`, +`two_var_comprehensions` and `all` (enables all extensions). `--enable_partial_eval` enables partial evaluations example: -`%option --container 'google.protobuf'` +`%option --container 'google.protobuf'` `%option --extension 'strings'` @@ -175,7 +176,7 @@ example: #### reset -`%reset` drops all options and let expressions, returning the evaluator to a +`%reset` drops all options and let expressions, returning the evaluator to a starting empty state. ### Evaluation Model From 5f44021478e018507246acc3dcb1d2958977228f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 21 Apr 2025 16:38:31 -0700 Subject: [PATCH 38/46] Check arg count when validating optFieldSelect (#1168) Add a check for correct call shape before indexing into args in checkOptSelect. Manually constructed ASTs could lead to a panic. 
--- checker/checker.go | 11 +++++++++++ checker/checker_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ checker/errors.go | 4 ++++ 3 files changed, 55 insertions(+) diff --git a/checker/checker.go b/checker/checker.go index 6824af7a5..a9e04fc26 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -145,6 +145,17 @@ func (c *checker) checkSelect(e ast.Expr) { func (c *checker) checkOptSelect(e ast.Expr) { // Collect metadata related to the opt select call packaged by the parser. call := e.AsCall() + if len(call.Args()) != 2 || call.IsMemberFunction() { + t := "" + if call.IsMemberFunction() { + t = " member call with" + } + c.errors.notAnOptionalFieldSelectionCall(e.ID(), c.location(e), + fmt.Sprintf( + "incorrect signature.%s argument count: %d%s", t, len(call.Args()))) + return + } + operand := call.Args()[0] field := call.Args()[1] fieldName, isString := maybeUnwrapString(field) diff --git a/checker/checker_test.go b/checker/checker_test.go index 23b17f3ab..a2b6d3212 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -2595,6 +2595,46 @@ func TestCheckErrorData(t *testing.T) { } } +func TestCheckInvalidOptSelectMember(t *testing.T) { + fac := ast.NewExprFactory() + target := fac.NewStruct(1, "Foo", nil) + arg1 := fac.NewStruct(2, "Foo", nil) + arg2 := fac.NewLiteral(3, types.String("field")) + call := fac.NewMemberCall(4, "_?._", target, arg1, arg2) + + // This is not valid syntax, just for illustration purposes. + src := common.NewTextSource("Foo{}._?._(Foo{}, 'field')") + parsed := ast.NewAST(call, ast.NewSourceInfo(src)) + reg := newTestRegistry(t) + env, err := NewEnv(containers.DefaultContainer, reg) + if err != nil { + t.Fatalf("NewEnv(cont, reg) failed: %v", err) + } + _, iss := Check(parsed, src, env) + if !strings.Contains(iss.ToDisplayString(), "incorrect signature. member call") { + t.Errorf("got %s, wanted 'incorrect signature. 
member call'", iss.ToDisplayString()) + } +} + +func TestCheckInvalidOptSelectMissingArg(t *testing.T) { + fac := ast.NewExprFactory() + arg1 := fac.NewStruct(1, "Foo", nil) + call := fac.NewCall(2, "_?._", arg1) + + // This is not valid syntax, just for illustration purposes. + src := common.NewTextSource("_?._(Foo{})") + parsed := ast.NewAST(call, ast.NewSourceInfo(src)) + reg := newTestRegistry(t) + env, err := NewEnv(containers.DefaultContainer, reg) + if err != nil { + t.Fatalf("NewEnv(cont, reg) failed: %v", err) + } + _, iss := Check(parsed, src, env) + if !strings.Contains(iss.ToDisplayString(), "incorrect signature. argument count: 1") { + t.Errorf("got %s, wanted 'incorrect signature. argument count: 1'", iss.ToDisplayString()) + } +} + func TestCheckInvalidLiteral(t *testing.T) { fac := ast.NewExprFactory() durLiteral := fac.NewLiteral(1, types.Duration{Duration: time.Second}) diff --git a/checker/errors.go b/checker/errors.go index 8b3bf0b8b..3535440ba 100644 --- a/checker/errors.go +++ b/checker/errors.go @@ -45,6 +45,10 @@ func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *type FormatCELType(t)) } +func (e *typeErrors) notAnOptionalFieldSelectionCall(id int64, l common.Location, err string) { + e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %s", err) +} + func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field ast.Expr) { e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field) } From 9d29b7fe661ceb1e46852a03986b108b16bf0be8 Mon Sep 17 00:00:00 2001 From: "zev.ac" Date: Tue, 22 Apr 2025 12:05:47 +0530 Subject: [PATCH 39/46] Refactoring changes to create a test runner option from passed flags, correct indentation and add package level comment for test (#1165) * corrected indentation and added package comment for /test * fix indentation for required_labels test file * added test runner options from flags as a public method --- 
policy/testdata/context_pb/tests.yaml | 16 +- policy/testdata/nested_rule7/tests.yaml | 32 +-- policy/testdata/pb/tests.yaml | 28 +-- policy/testdata/required_labels/tests.yaml | 116 +++++----- .../restricted_destinations/base_config.yaml | 4 +- .../restricted_destinations/tests.yaml | 204 +++++++++--------- policy/testdata/unnest/tests.yaml | 48 ++--- test/suite.go | 1 + tools/celtest/test_runner.go | 38 +++- .../celtest/testdata/custom_policy.celpolicy | 2 +- 10 files changed, 256 insertions(+), 233 deletions(-) diff --git a/policy/testdata/context_pb/tests.yaml b/policy/testdata/context_pb/tests.yaml index 80849783b..a7e0116f0 100644 --- a/policy/testdata/context_pb/tests.yaml +++ b/policy/testdata/context_pb/tests.yaml @@ -16,13 +16,13 @@ description: "Protobuf input tests" section: - name: "valid" tests: - - name: "good spec" - context_expr: "test.TestAllTypes{single_int32: 10}" - output: - expr: "optional.none()" + - name: "good spec" + context_expr: "test.TestAllTypes{single_int32: 10}" + output: + expr: "optional.none()" - name: "invalid" tests: - - name: "bad spec" - context_expr: "test.TestAllTypes{single_int32: 11}" - output: - value: "invalid spec, got single_int32=11, wanted <= 10" + - name: "bad spec" + context_expr: "test.TestAllTypes{single_int32: 11}" + output: + value: "invalid spec, got single_int32=11, wanted <= 10" diff --git a/policy/testdata/nested_rule7/tests.yaml b/policy/testdata/nested_rule7/tests.yaml index f740c7639..ec2896878 100644 --- a/policy/testdata/nested_rule7/tests.yaml +++ b/policy/testdata/nested_rule7/tests.yaml @@ -16,27 +16,27 @@ description: "Nested rule tests which explore optional vs non-optional returns" section: - name: "valid" tests: - - name: "x=1" - input: + - name: "x=1" + input: x: value: 1 - output: - expr: "optional.none()" - - name: "x=2" - input: + output: + expr: "optional.none()" + - name: "x=2" + input: x: value: 2 - output: - value: false - - name: "x=3" - input: + output: + value: false + - name: "x=3" + 
input: x: value: 3 - output: - value: true - - name: "x=4" - input: + output: + value: true + - name: "x=4" + input: x: value: 4 - output: - value: true + output: + value: true diff --git a/policy/testdata/pb/tests.yaml b/policy/testdata/pb/tests.yaml index a39f7b73f..3b1ece5ce 100644 --- a/policy/testdata/pb/tests.yaml +++ b/policy/testdata/pb/tests.yaml @@ -16,19 +16,19 @@ description: "Protobuf input tests" section: - name: "valid" tests: - - name: "good spec" - input: - spec: - expr: > - test.TestAllTypes{single_int32: 10} - output: - expr: "optional.none()" + - name: "good spec" + input: + spec: + expr: > + test.TestAllTypes{single_int32: 10} + output: + expr: "optional.none()" - name: "invalid" tests: - - name: "bad spec" - input: - spec: - expr: > - test.TestAllTypes{single_int32: 11} - output: - value: "invalid spec, got single_int32=11, wanted <= 10" + - name: "bad spec" + input: + spec: + expr: > + test.TestAllTypes{single_int32: 11} + output: + value: "invalid spec, got single_int32=11, wanted <= 10" diff --git a/policy/testdata/required_labels/tests.yaml b/policy/testdata/required_labels/tests.yaml index 2159b1b24..4296c6914 100644 --- a/policy/testdata/required_labels/tests.yaml +++ b/policy/testdata/required_labels/tests.yaml @@ -16,65 +16,65 @@ description: "Required labels conformance tests" section: - name: "valid" tests: - - name: "matching" - input: - spec: - value: - labels: - env: prod - experiment: "group b" - resource: - value: - labels: - env: prod - experiment: "group b" - release: "v0.1.0" - output: - expr: "optional.none()" + - name: "matching" + input: + spec: + value: + labels: + env: prod + experiment: "group b" + resource: + value: + labels: + env: prod + experiment: "group b" + release: "v0.1.0" + output: + expr: "optional.none()" - name: "missing" tests: - - name: "env" - input: - spec: - value: - labels: - env: prod - experiment: "group b" - resource: - value: - labels: - experiment: "group b" - release: "v0.1.0" - output: - value: 
"missing one or more required labels: [\"env\"]" - - name: "experiment" - input: - spec: - value: - labels: - env: prod - experiment: "group b" - resource: - value: - labels: - env: staging - release: "v0.1.0" - output: - value: "missing one or more required labels: [\"experiment\"]" + - name: "env" + input: + spec: + value: + labels: + env: prod + experiment: "group b" + resource: + value: + labels: + experiment: "group b" + release: "v0.1.0" + output: + value: "missing one or more required labels: [\"env\"]" + - name: "experiment" + input: + spec: + value: + labels: + env: prod + experiment: "group b" + resource: + value: + labels: + env: staging + release: "v0.1.0" + output: + value: "missing one or more required labels: [\"experiment\"]" - name: "invalid" tests: - - name: "env" - input: - spec: - value: - labels: - env: prod - experiment: "group b" - resource: - value: - labels: - env: staging - experiment: "group b" - release: "v0.1.0" - output: - value: "invalid values provided on one or more labels: [\"env\"]" + - name: "env" + input: + spec: + value: + labels: + env: prod + experiment: "group b" + resource: + value: + labels: + env: staging + experiment: "group b" + release: "v0.1.0" + output: + value: "invalid values provided on one or more labels: [\"env\"]" diff --git a/policy/testdata/restricted_destinations/base_config.yaml b/policy/testdata/restricted_destinations/base_config.yaml index 615a8b915..8cf015154 100644 --- a/policy/testdata/restricted_destinations/base_config.yaml +++ b/policy/testdata/restricted_destinations/base_config.yaml @@ -30,8 +30,8 @@ variables: - name: "request" type_name: "map" params: - - type_name: "string" - - type_name: "dyn" + - type_name: "string" + - type_name: "dyn" - name: "resource" type_name: "map" params: diff --git a/policy/testdata/restricted_destinations/tests.yaml b/policy/testdata/restricted_destinations/tests.yaml index e448fb1a9..f7ae36550 100644 --- a/policy/testdata/restricted_destinations/tests.yaml +++ 
b/policy/testdata/restricted_destinations/tests.yaml @@ -16,107 +16,107 @@ description: Restricted destinations conformance tests. section: - name: "valid" tests: - - name: "ip_allowed" - input: - "spec.origin": - value: "us" - "spec.restricted_destinations": - value: - - "cu" - - "ir" - - "kp" - - "sd" - - "sy" - "destination.ip": - value: "10.0.0.1" - "origin.ip": - value: "10.0.0.1" - request: - value: - auth: - claims: {} - resource: - value: - name: "/company/acme/secrets/doomsday-device" - labels: - location: "us" - output: - value: false # false means unrestricted - - name: "nationality_allowed" - input: - "spec.origin": - value: "us" - "spec.restricted_destinations": - value: - - "cu" - - "ir" - - "kp" - - "sd" - - "sy" - "destination.ip": - value: "10.0.0.1" - request: - value: - auth: - claims: - nationality: "us" - resource: - value: - name: "/company/acme/secrets/doomsday-device" - labels: - location: "us" - output: - value: false + - name: "ip_allowed" + input: + spec.origin: + value: "us" + spec.restricted_destinations: + value: + - "cu" + - "ir" + - "kp" + - "sd" + - "sy" + destination.ip: + value: "10.0.0.1" + origin.ip: + value: "10.0.0.1" + request: + value: + auth: + claims: {} + resource: + value: + name: "/company/acme/secrets/doomsday-device" + labels: + location: "us" + output: + value: false # false means unrestricted + - name: "nationality_allowed" + input: + spec.origin: + value: "us" + spec.restricted_destinations: + value: + - "cu" + - "ir" + - "kp" + - "sd" + - "sy" + destination.ip: + value: "10.0.0.1" + request: + value: + auth: + claims: + nationality: "us" + resource: + value: + name: "/company/acme/secrets/doomsday-device" + labels: + location: "us" + output: + value: false - name: "invalid" tests: - - name: "destination_ip_prohibited" - input: - "spec.origin": - value: "us" - "spec.restricted_destinations": - value: - - "cu" - - "ir" - - "kp" - - "sd" - - "sy" - "destination.ip": - value: "123.123.123.123" - "origin.ip": - value: 
"10.0.0.1" - request: - value: - auth: - claims: {} - resource: - value: - name: "/company/acme/secrets/doomsday-device" - labels: - location: "us" - output: - value: true # true means restricted - - name: "resource_nationality_prohibited" - input: - "spec.origin": - value: "us" - "spec.restricted_destinations": - value: - - "cu" - - "ir" - - "kp" - - "sd" - - "sy" - "destination.ip": - value: "10.0.0.1" - request: - value: - auth: - claims: - nationality: "us" - resource: - value: - name: "/company/acme/secrets/doomsday-device" - labels: - location: "cu" - output: - value: true + - name: "destination_ip_prohibited" + input: + spec.origin: + value: "us" + spec.restricted_destinations: + value: + - "cu" + - "ir" + - "kp" + - "sd" + - "sy" + destination.ip: + value: "123.123.123.123" + origin.ip: + value: "10.0.0.1" + request: + value: + auth: + claims: {} + resource: + value: + name: "/company/acme/secrets/doomsday-device" + labels: + location: "us" + output: + value: true # true means restricted + - name: "resource_nationality_prohibited" + input: + spec.origin: + value: "us" + spec.restricted_destinations: + value: + - "cu" + - "ir" + - "kp" + - "sd" + - "sy" + destination.ip: + value: "10.0.0.1" + request: + value: + auth: + claims: + nationality: "us" + resource: + value: + name: "/company/acme/secrets/doomsday-device" + labels: + location: "cu" + output: + value: true diff --git a/policy/testdata/unnest/tests.yaml b/policy/testdata/unnest/tests.yaml index 31a8770d7..372132ff9 100644 --- a/policy/testdata/unnest/tests.yaml +++ b/policy/testdata/unnest/tests.yaml @@ -16,37 +16,37 @@ description: "Unnest tests unnesting of comprehension sequences" section: - name: "divisible by 2" tests: - - name: "true" - input: + - name: "true" + input: values: expr: "[4, 6]" - output: - value: "some divisible by 2" - - name: "false" - input: - values: - expr: "[1, 3, 5]" - output: - expr: "optional.none()" - - name: "empty-set" - input: - values: - expr: "[1, 2]" - output: - 
expr: "optional.none()" + output: + value: "some divisible by 2" + - name: "false" + input: + values: + expr: "[1, 3, 5]" + output: + expr: "optional.none()" + - name: "empty-set" + input: + values: + expr: "[1, 2]" + output: + expr: "optional.none()" - name: "divisible by 4" tests: - - name: "true" - input: + - name: "true" + input: values: expr: "[4, 7]" - output: - value: "at least one divisible by 4" + output: + value: "at least one divisible by 4" - name: "power of 6" tests: - - name: "true" - input: + - name: "true" + input: values: expr: "[6, 7]" - output: - value: "at least one power of 6" + output: + value: "at least one power of 6" diff --git a/test/suite.go b/test/suite.go index 2b499e45d..cbb683542 100644 --- a/test/suite.go +++ b/test/suite.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package test provides a YAML-serializable test suite for CEL tests. package test // Suite is a collection of tests designed to evaluate the correctness of diff --git a/tools/celtest/test_runner.go b/tools/celtest/test_runner.go index 39f29dea3..bfbd9943f 100644 --- a/tools/celtest/test_runner.go +++ b/tools/celtest/test_runner.go @@ -74,8 +74,8 @@ type TestRunnerOption func(*TestRunner) (*TestRunner, error) // - configure the Compiler used for parsing and compiling the expression // - configure the Test Runner used for parsing and executing the tests func TriggerTests(t *testing.T, testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) { - testRunnerOptions := testRunnerOptions(testRunnerOpts, testCompilerOpts...) - tr, err := NewTestRunner(testRunnerOptions...) + testRunnerOption := TestRunnerOptionsFromFlags(testRunnerOpts, testCompilerOpts...) 
+ tr, err := NewTestRunner(testRunnerOption) if err != nil { t.Fatalf("error creating test runner: %v", err) } @@ -97,12 +97,34 @@ func TriggerTests(t *testing.T, testRunnerOpts []TestRunnerOption, testCompilerO } } -func testRunnerOptions(testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) []TestRunnerOption { - compilerOpt := testRunnerCompilerFromFlags(testCompilerOpts...) - testSuiteParserOpt := DefaultTestSuiteParser(testSuitePath) - fileDescriptorSetOpt := AddFileDescriptorSet(fileDescriptorSetPath) - testRunnerExprOpt := testRunnerExpressionsFromFlags() - return append([]TestRunnerOption{compilerOpt, testSuiteParserOpt, fileDescriptorSetOpt, testRunnerExprOpt}, testRunnerOpts...) +// TestRunnerOptionsFromFlags returns a TestRunnerOption which configures the following attributes +// of the test runner using the parsed flags and the optionally provided test runner and test compiler options: +// - Test compiler - The `file_descriptor_set`, `base_config_path` and `config_path` flags are used +// to set up the test compiler. The optionally provided test compiler options are also used to +// augment the test compiler. +// - Test suite parser - The `test_suite_path` flag is used to set up the test suite parser. +// - File descriptor set path - The value of the `file_descriptor_set` flag is set as the +// File Descriptor Set Path of the test runner. +// - Test expression - The `cel_expr` flag is used to populate the test expressions which need to be +// evaluated by the test runner. +func TestRunnerOptionsFromFlags(testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + opts := []TestRunnerOption{ + testRunnerCompilerFromFlags(testCompilerOpts...), + DefaultTestSuiteParser(testSuitePath), + AddFileDescriptorSet(fileDescriptorSetPath), + testRunnerExpressionsFromFlags(), + } + opts = append(opts, testRunnerOpts...) 
+ var err error + for _, opt := range opts { + tr, err = opt(tr) + if err != nil { + return nil, err + } + } + return tr, nil + } } func testRunnerCompilerFromFlags(testCompilerOpts ...any) TestRunnerOption { diff --git a/tools/celtest/testdata/custom_policy.celpolicy b/tools/celtest/testdata/custom_policy.celpolicy index 3867b26fe..d1701fa03 100644 --- a/tools/celtest/testdata/custom_policy.celpolicy +++ b/tools/celtest/testdata/custom_policy.celpolicy @@ -23,4 +23,4 @@ rule: - condition: | variable1 == 1 || variable2 == "known" output: "true" - - output: "false" \ No newline at end of file + - output: "false" From 9ed72dd8273e5c43d10036192c17243668c3060e Mon Sep 17 00:00:00 2001 From: "zev.ac" Date: Thu, 24 Apr 2025 22:48:20 +0530 Subject: [PATCH 40/46] fix test runner test cases (#1170) * fix test runner test cases * add empty check on programs and tests creation --- tools/celtest/test_runner.go | 6 ++++++ tools/celtest/test_runner_test.go | 6 +++--- tools/celtest/testdata/raw_expr.cel | 2 +- tools/celtest/testdata/raw_expr_tests.yaml | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tools/celtest/test_runner.go b/tools/celtest/test_runner.go index bfbd9943f..2adf3a215 100644 --- a/tools/celtest/test_runner.go +++ b/tools/celtest/test_runner.go @@ -83,10 +83,16 @@ func TriggerTests(t *testing.T, testRunnerOpts []TestRunnerOption, testCompilerO if err != nil { t.Fatalf("error creating programs: %v", err) } + if len(programs) == 0 { + t.Fatalf("no programs created for the provided expressions") + } tests, err := tr.Tests(t) if err != nil { t.Fatalf("error creating tests: %v", err) } + if len(tests) == 0 { + t.Fatalf("no tests found") + } for _, test := range tests { t.Run(test.name, func(t *testing.T) { err := tr.ExecuteTest(t, programs, test) diff --git a/tools/celtest/test_runner_test.go b/tools/celtest/test_runner_test.go index 3530c2c7c..60cdd28d1 100644 --- a/tools/celtest/test_runner_test.go +++ b/tools/celtest/test_runner_test.go @@ 
-61,14 +61,14 @@ func setupTests() []*testCase { { name: "raw expression file test", celExpression: "testdata/raw_expr.cel", - testSuitePath: "testdata/raw_expr_tests", + testSuitePath: "testdata/raw_expr_tests.yaml", configPath: "testdata/config.yaml", opts: []any{fnEnvOption()}, }, { name: "raw expression test", - celExpression: "'i + fn(j) == 42'", - testSuitePath: "testdata/raw_expr_tests", + celExpression: "i + fn(j) == 42", + testSuitePath: "testdata/raw_expr_tests.yaml", configPath: "testdata/config.yaml", opts: []any{fnEnvOption()}, }, diff --git a/tools/celtest/testdata/raw_expr.cel b/tools/celtest/testdata/raw_expr.cel index 63386498f..4906f7af7 100644 --- a/tools/celtest/testdata/raw_expr.cel +++ b/tools/celtest/testdata/raw_expr.cel @@ -1 +1 @@ -"'i + fn(j) == 42'" \ No newline at end of file +i + fn(j) == 42 \ No newline at end of file diff --git a/tools/celtest/testdata/raw_expr_tests.yaml b/tools/celtest/testdata/raw_expr_tests.yaml index 547d4319f..3cedc7619 100644 --- a/tools/celtest/testdata/raw_expr_tests.yaml +++ b/tools/celtest/testdata/raw_expr_tests.yaml @@ -29,6 +29,6 @@ section: i: value: 22 j: - value:42 + value: 42 output: value: false From ad763826808fa99e8960f6e6ea9f902540c528b9 Mon Sep 17 00:00:00 2001 From: Hari Raghu <1621711+haribalan@users.noreply.github.com> Date: Thu, 24 Apr 2025 21:18:16 -0700 Subject: [PATCH 41/46] Sqrt func (#1166) * Adding math square root function * Adding info to readme on math square root function * Handle overload error and test NaN * adding version 2 --- ext/README.md | 17 +++++++++++++++++ ext/math.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ ext/math_test.go | 19 +++++++++++++++++++ 3 files changed, 83 insertions(+) diff --git a/ext/README.md b/ext/README.md index 4620204fc..8e3ce0cdd 100644 --- a/ext/README.md +++ b/ext/README.md @@ -356,6 +356,23 @@ Examples: math.isFinite(0.0/0.0) // returns false math.isFinite(1.2) // returns true +### Math.Sqrt + +Introduced at version: 2 + +Returns the 
square root of the given input as double +Throws error for negative or non-numeric inputs + + math.sqrt() -> + math.sqrt() -> + math.sqrt() -> + +Examples: + + math.sqrt(81) // returns 9.0 + math.sqrt(985.25) // returns 31.388692231439016 + math.sqrt(-15) // returns NaN + ## Protos Protos configure extended macros and functions for proto manipulation. diff --git a/ext/math.go b/ext/math.go index 250246db1..6df8e3773 100644 --- a/ext/math.go +++ b/ext/math.go @@ -325,6 +325,23 @@ import ( // // math.isFinite(0.0/0.0) // returns false // math.isFinite(1.2) // returns true +// +// # Math.Sqrt +// +// Introduced at version: 2 +// +// Returns the square root of the given input as double +// Throws error for negative or non-numeric inputs +// +// math.sqrt() -> +// math.sqrt() -> +// math.sqrt() -> +// +// Examples: +// +// math.sqrt(81) // returns 9.0 +// math.sqrt(985.25) // returns 31.388692231439016 +// math.sqrt(-15) // returns NaN func Math(options ...MathOption) cel.EnvOption { m := &mathLib{version: math.MaxUint32} for _, o := range options { @@ -357,6 +374,9 @@ const ( absFunc = "math.abs" signFunc = "math.sign" + // SquareRoot function + sqrtFunc = "math.sqrt" + // Bitwise functions bitAndFunc = "math.bitAnd" bitOrFunc = "math.bitOr" @@ -548,6 +568,18 @@ func (lib *mathLib) CompileOptions() []cel.EnvOption { ), ) } + if lib.version >= 2 { + opts = append(opts, + cel.Function(sqrtFunc, + cel.Overload("math_sqrt_double", []*cel.Type{cel.DoubleType}, cel.DoubleType, + cel.UnaryBinding(sqrt)), + cel.Overload("math_sqrt_int", []*cel.Type{cel.IntType}, cel.DoubleType, + cel.UnaryBinding(sqrt)), + cel.Overload("math_sqrt_uint", []*cel.Type{cel.UintType}, cel.DoubleType, + cel.UnaryBinding(sqrt)), + ), + ) + } return opts } @@ -691,6 +723,21 @@ func sign(val ref.Val) ref.Val { } } + +func sqrt(val ref.Val) ref.Val { + switch v := val.(type) { + case types.Double: + return types.Double(math.Sqrt(float64(v))) + case types.Int: + return types.Double(math.Sqrt(float64(v))) 
+ case types.Uint: + return types.Double(math.Sqrt(float64(v))) + default: + return types.NewErr("no such overload: sqrt") + } +} + + func bitAndPairInt(first, second ref.Val) ref.Val { l := first.(types.Int) r := second.(types.Int) diff --git a/ext/math_test.go b/ext/math_test.go index 0c82c9811..878954c75 100644 --- a/ext/math_test.go +++ b/ext/math_test.go @@ -181,6 +181,15 @@ func TestMath(t *testing.T) { {expr: "math.abs(1) == 1"}, {expr: "math.abs(-234.5) == 234.5"}, {expr: "math.abs(234.5) == 234.5"}, + + // Tests for Square root function + {expr: "math.sqrt(49.0) == 7.0"}, + {expr: "math.sqrt(0) == 0.0"}, + {expr: "math.sqrt(1) == 1.0"}, + {expr: "math.sqrt(25u) == 5.0"}, + {expr: "math.sqrt(82) == 9.055385138137417"}, + {expr: "math.sqrt(985.25) == 31.388692231439016"}, + {expr: "math.isNaN(math.sqrt(-15.34))"}, } env := testMathEnv(t, @@ -472,6 +481,10 @@ func TestMathRuntimeErrors(t *testing.T) { expr: "math.trunc(dyn(1u))", err: "no such overload: math.trunc(uint)", }, + { + expr: "math.sqrt(dyn(''))", + err: "no such overload: math.sqrt(string)", + }, } env := testMathEnv(t, @@ -599,6 +612,12 @@ func TestMathVersions(t *testing.T) { "bitShiftRight": `math.bitShiftRight(4, 2) == 1`, }, }, + { + version: 2, + supportedFunctions: map[string]string{ + "sqrt": `math.sqrt(25) == 5.0`, + }, + }, } for _, lib := range versionCases { env, err := cel.NewEnv(Math(MathVersion(lib.version))) From 845f2a8ec46a147297bb648edf41cfc2b68fb189 Mon Sep 17 00:00:00 2001 From: l46kok Date: Wed, 30 Apr 2025 17:14:26 -0700 Subject: [PATCH 42/46] Fix lastIndexOf behavior against an empty string in strings extension (#1173) * Fix lastIndexOf behavior against an empty string in strings extension --- ext/strings.go | 4 ++++ ext/strings_test.go | 1 + 2 files changed, 5 insertions(+) diff --git a/ext/strings.go b/ext/strings.go index 88b4d7f03..de65421f6 100644 --- a/ext/strings.go +++ b/ext/strings.go @@ -607,6 +607,10 @@ func lastIndexOf(str, substr string) (int64, error) { if 
substr == "" { return int64(len(runes)), nil } + + if len(str) < len(substr) { + return -1, nil + } return lastIndexOfOffset(str, substr, int64(len(runes)-1)) } diff --git a/ext/strings_test.go b/ext/strings_test.go index 3a8adeb09..17be558f8 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -53,6 +53,7 @@ var stringTests = []struct { {expr: `'hello wello'.indexOf('ello', 6) == 7`}, {expr: `'hello wello'.indexOf('elbo room!!') == -1`}, {expr: `'hello wello'.indexOf('elbo room!!!') == -1`}, + {expr: `''.lastIndexOf('@@') == -1`}, {expr: `'tacocat'.lastIndexOf('') == 7`}, {expr: `'tacocat'.lastIndexOf('at') == 5`}, {expr: `'tacocat'.lastIndexOf('none') == -1`}, From f6d3c92171c2c8732a8d0a4b24d6729df4261520 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 7 May 2025 14:15:30 -0700 Subject: [PATCH 43/46] Update test runner to avoid using flags when not necessary (#1174) --- tools/BUILD.bazel | 5 +++ tools/celtest/test_runner.go | 58 ++++++++++++++++++------------- tools/celtest/test_runner_test.go | 28 +++++++++++---- 3 files changed, 61 insertions(+), 30 deletions(-) create mode 100644 tools/BUILD.bazel diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel new file mode 100644 index 000000000..0988550c0 --- /dev/null +++ b/tools/BUILD.bazel @@ -0,0 +1,5 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) diff --git a/tools/celtest/test_runner.go b/tools/celtest/test_runner.go index 2adf3a215..4e0cb3b8d 100644 --- a/tools/celtest/test_runner.go +++ b/tools/celtest/test_runner.go @@ -59,7 +59,6 @@ func init() { flag.StringVar(&configPath, "config_path", "", "path to a config file") flag.StringVar(&baseConfigPath, "base_config_path", "", "path to a base config file") flag.StringVar(&celExpression, "cel_expr", "", "CEL expression to test") - flag.Parse() } // TestRunnerOption is used to configure the following attributes of the Test Runner: @@ -73,9 +72,8 @@ type TestRunnerOption 
func(*TestRunner) (*TestRunner, error) // with the provided set of options. The options can be used to: // - configure the Compiler used for parsing and compiling the expression // - configure the Test Runner used for parsing and executing the tests -func TriggerTests(t *testing.T, testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) { - testRunnerOption := TestRunnerOptionsFromFlags(testRunnerOpts, testCompilerOpts...) - tr, err := NewTestRunner(testRunnerOption) +func TriggerTests(t *testing.T, testRunnerOpts ...TestRunnerOption) { + tr, err := NewTestRunner(testRunnerOpts...) if err != nil { t.Fatalf("error creating test runner: %v", err) } @@ -114,12 +112,15 @@ func TriggerTests(t *testing.T, testRunnerOpts []TestRunnerOption, testCompilerO // - Test expression - The `cel_expr` flag is used to populate the test expressions which need to be // evaluated by the test runner. func TestRunnerOptionsFromFlags(testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) TestRunnerOption { + if !flag.Parsed() { + flag.Parse() + } return func(tr *TestRunner) (*TestRunner, error) { opts := []TestRunnerOption{ testRunnerCompilerFromFlags(testCompilerOpts...), DefaultTestSuiteParser(testSuitePath), AddFileDescriptorSet(fileDescriptorSetPath), - testRunnerExpressionsFromFlags(), + TestExpression(celExpression), } opts = append(opts, testRunnerOpts...) var err error @@ -145,25 +146,7 @@ func testRunnerCompilerFromFlags(testCompilerOpts ...any) TestRunnerOption { opts = append(opts, compiler.EnvironmentFile(configPath)) } opts = append(opts, testCompilerOpts...) - return func(tr *TestRunner) (*TestRunner, error) { - c, err := compiler.NewCompiler(opts...) 
- if err != nil { - return nil, err - } - tr.Compiler = c - return tr, nil - } -} - -func testRunnerExpressionsFromFlags() TestRunnerOption { - return func(tr *TestRunner) (*TestRunner, error) { - if celExpression != "" { - tr.Expressions = append(tr.Expressions, &compiler.CompiledExpression{Path: celExpression}) - tr.Expressions = append(tr.Expressions, &compiler.FileExpression{Path: celExpression}) - tr.Expressions = append(tr.Expressions, &compiler.RawExpression{Value: celExpression}) - } - return tr, nil - } + return TestCompiler(opts...) } // TestSuiteParser is an interface for parsing a test suite: @@ -293,6 +276,33 @@ func NewTestRunner(opts ...TestRunnerOption) (*TestRunner, error) { return tr, nil } +// TestExpression returns a TestRunnerOption which configures a policy file, expression file, or raw expression +// for testing +func TestExpression(value string) TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + if value != "" { + tr.Expressions = append(tr.Expressions, + &compiler.CompiledExpression{Path: value}, + &compiler.FileExpression{Path: value}, + &compiler.RawExpression{Value: value}, + ) + } + return tr, nil + } +} + +// TestCompiler configures a compiler to use for testing. +func TestCompiler(compileOpts ...any) TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + c, err := compiler.NewCompiler(compileOpts...) + if err != nil { + return nil, err + } + tr.Compiler = c + return tr, nil + } +} + // AddFileDescriptorSet creates a Test Runner Option which adds a file descriptor set to the test // runner. The file descriptor set is used to register proto messages in the global proto registry. 
func AddFileDescriptorSet(path string) TestRunnerOption { diff --git a/tools/celtest/test_runner_test.go b/tools/celtest/test_runner_test.go index 60cdd28d1..c434d6b26 100644 --- a/tools/celtest/test_runner_test.go +++ b/tools/celtest/test_runner_test.go @@ -121,7 +121,7 @@ func TestTriggerTestsWithRunnerOptions(t *testing.T) { return tr, nil }) opts := []TestRunnerOption{compilerOpt, testSuiteParser, testCELPolicy} - TriggerTests(t, opts) + TriggerTests(t, opts...) }) } @@ -187,10 +187,26 @@ func fnEnvOption() cel.EnvOption { // TestTriggerTests tests different scenarios of the TriggerTestsFromCompiler function. func TestTriggerTests(t *testing.T) { for _, tc := range setupTests() { - celExpression = tc.celExpression - testSuitePath = tc.testSuitePath - configPath = tc.configPath - fileDescriptorSetPath = tc.fileDescriptorSetPath - TriggerTests(t, nil, tc.opts...) + tc := tc + t.Run(tc.name, func(t *testing.T) { + var testOpts []TestRunnerOption + var compileOpts []any = make([]any, 0, len(tc.opts)+2) + for _, opt := range tc.opts { + compileOpts = append(compileOpts, opt) + } + if tc.fileDescriptorSetPath != "" { + compileOpts = append(compileOpts, compiler.TypeDescriptorSetFile(tc.fileDescriptorSetPath)) + } + if tc.configPath != "" { + compileOpts = append(compileOpts, compiler.EnvironmentFile(tc.configPath)) + } + testOpts = append(testOpts, + TestCompiler(compileOpts...), + DefaultTestSuiteParser(tc.testSuitePath), + AddFileDescriptorSet(tc.fileDescriptorSetPath), + TestExpression(tc.celExpression), + ) + TriggerTests(t, testOpts...) 
+ }) } } From b1209b87fefb28ed128271cfa2dfbfb5a25fb806 Mon Sep 17 00:00:00 2001 From: "zev.ac" Date: Tue, 20 May 2025 23:47:43 +0530 Subject: [PATCH 44/46] Add bazel rule to trigger cel tests and return policy metadata while creating CEL programs (#1176) * add bazel rule to trigger cel test --- common/decls/decls.go | 2 +- policy/BUILD.bazel | 124 ++++++++++++++++++++++++ policy/test/BUILD.bazel | 42 ++++++++ policy/test/cel_test_runner.go | 51 ++++++++++ policy/test/k8s_cel_test_runner.go | 39 ++++++++ test/cel_go_test.bzl | 150 +++++++++++++++++++++++++++++ tools/celtest/BUILD.bazel | 12 +++ tools/celtest/test_runner.go | 91 +++++++++++++++-- tools/celtest/test_runner_test.go | 3 +- tools/compiler/compiler.go | 2 +- tools/compiler/compiler_test.go | 2 +- tools/go.mod | 1 + tools/go.sum | 2 + 13 files changed, 507 insertions(+), 14 deletions(-) create mode 100644 policy/test/BUILD.bazel create mode 100644 policy/test/cel_test_runner.go create mode 100644 policy/test/k8s_cel_test_runner.go create mode 100644 test/cel_go_test.bzl diff --git a/common/decls/decls.go b/common/decls/decls.go index 759b1d16b..a4a51c3f2 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -303,7 +303,7 @@ func (f *FunctionDecl) OverloadDecls() []*OverloadDecl { return overloads } -// Returns true if the function has late bindings. A function cannot mix late bindings with other bindings. +// HasLateBinding returns true if the function has late bindings. A function cannot mix late bindings with other bindings. func (f *FunctionDecl) HasLateBinding() bool { if f == nil { return false diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index 15facc55d..02a856b3d 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -13,6 +13,8 @@ # limitations under the License. 
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("@rules_proto//proto:defs.bzl", "proto_descriptor_set") +load("//test:cel_go_test.bzl", "cel_go_test") package( default_visibility = [ @@ -76,3 +78,125 @@ filegroup( name = "testdata", srcs = glob(["testdata/**"]), ) + +proto_descriptor_set( + name = "test_all_types_fds", + deps = [ + "//test/proto3pb:test_all_types_proto", + ], +) + +cel_go_test( + name = "context_pb_policy", + cel_expr = "testdata/context_pb/policy.yaml", + config = "testdata/context_pb/config.yaml", + file_descriptor_set = ":test_all_types_fds", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/context_pb/tests.yaml", +) + +cel_go_test( + name = "k8s_policy", + cel_expr = "testdata/k8s/policy.yaml", + config = "testdata/k8s/config.yaml", + deps = ["//policy:go_default_library"], + test_src = "//policy/test:k8s_cel_test_runner.go", + test_suite = "testdata/k8s/tests.yaml", +) + +cel_go_test( + name = "limits_policy", + cel_expr = "testdata/limits/policy.yaml", + config = "testdata/limits/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/limits/tests.yaml", +) + +cel_go_test( + name = "nested_rule_policy", + cel_expr = "testdata/nested_rule/policy.yaml", + config = "testdata/nested_rule/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/nested_rule/tests.yaml", +) + +cel_go_test( + name = "nested_rule2_policy", + cel_expr = "testdata/nested_rule2/policy.yaml", + config = "testdata/nested_rule2/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/nested_rule2/tests.yaml", +) + +cel_go_test( + name = "nested_rule3_policy", + cel_expr = "testdata/nested_rule3/policy.yaml", + config = "testdata/nested_rule3/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/nested_rule3/tests.yaml", +) + +cel_go_test( + name = "nested_rule4_policy", + cel_expr = 
"testdata/nested_rule4/policy.yaml", + config = "testdata/nested_rule4/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/nested_rule4/tests.yaml", +) + +cel_go_test( + name = "nested_rule5_policy", + cel_expr = "testdata/nested_rule5/policy.yaml", + config = "testdata/nested_rule5/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/nested_rule5/tests.yaml", +) + +cel_go_test( + name = "nested_rule6_policy", + cel_expr = "testdata/nested_rule6/policy.yaml", + config = "testdata/nested_rule6/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/nested_rule6/tests.yaml", +) + +cel_go_test( + name = "nested_rule7_policy", + cel_expr = "testdata/nested_rule7/policy.yaml", + config = "testdata/nested_rule7/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/nested_rule7/tests.yaml", +) + +cel_go_test( + name = "pb_policy", + cel_expr = "testdata/pb/policy.yaml", + config = "testdata/pb/config.yaml", + file_descriptor_set = ":test_all_types_fds", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/pb/tests.yaml", +) + +cel_go_test( + name = "required_labels_policy", + cel_expr = "testdata/required_labels/policy.yaml", + config = "testdata/required_labels/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/required_labels/tests.yaml", +) + +cel_go_test( + name = "restricted_destinations_policy", + cel_expr = "testdata/restricted_destinations/policy.yaml", + config = "testdata/restricted_destinations/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/restricted_destinations/tests.yaml", +) + +cel_go_test( + name = "unnest_policy", + cel_expr = "testdata/unnest/policy.yaml", + config = "testdata/unnest/config.yaml", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/unnest/tests.yaml", +) diff --git 
a/policy/test/BUILD.bazel b/policy/test/BUILD.bazel new file mode 100644 index 000000000..ec12ca26a --- /dev/null +++ b/policy/test/BUILD.bazel @@ -0,0 +1,42 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +package( + default_visibility = [ + "//policy:__subpackages__" + ], + licenses = ["notice"], +) + +exports_files([ + "cel_test_runner.go", +]) + +go_library( + name = "test", + testonly = True, + srcs = [ + "cel_test_runner.go", + "k8s_cel_test_runner.go", + ], + deps = [ + "//cel:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//policy:go_default_library", + "//tools/celtest:go_default_library", + ], +) diff --git a/policy/test/cel_test_runner.go b/policy/test/cel_test_runner.go new file mode 100644 index 000000000..95269e13c --- /dev/null +++ b/policy/test/cel_test_runner.go @@ -0,0 +1,51 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "os" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/tools/celtest" +) + +// TestCEL triggers the celtest test runner with a list of custom options which are used to set up +// the compiler tool and test runner. +func TestCEL(t *testing.T) { + opts := []any{locationCodeEnvOption()} + testResourcesDir := os.Getenv("RUNFILES_DIR") + testRunnerOpt := celtest.TestRunnerOptionsFromFlags(testResourcesDir, nil, opts...) + celtest.TriggerTests(t, testRunnerOpt) +} + +func locationCodeEnvOption() cel.EnvOption { + return cel.Function("locationCode", + cel.Overload("locationCode_string", []*cel.Type{cel.StringType}, cel.StringType, + cel.UnaryBinding(locationCode))) +} + +func locationCode(ip ref.Val) ref.Val { + switch ip.(types.String) { + case "10.0.0.1": + return types.String("us") + case "10.0.0.2": + return types.String("de") + default: + return types.String("ir") + } +} diff --git a/policy/test/k8s_cel_test_runner.go b/policy/test/k8s_cel_test_runner.go new file mode 100644 index 000000000..26e6ad114 --- /dev/null +++ b/policy/test/k8s_cel_test_runner.go @@ -0,0 +1,39 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package test + +import ( + "os" + "testing" + + "github.com/google/cel-go/policy" + "github.com/google/cel-go/tools/celtest" +) + +// TestK8sCEL triggers compilation and test execution of a k8s policy which +// contains custom policy tags. Custom parser options are used to configure the +// parser to handle the custom policy tags. Tests are triggered with a list of +// custom CEL environment options. +func TestK8sCEL(t *testing.T) { + parserOpt := policy.ParserOption(testK8sPolicyParser) + testResourcesDir := os.Getenv("RUNFILES_DIR") + testRunnerOpt := celtest.TestRunnerOptionsFromFlags(testResourcesDir, nil, parserOpt) + celtest.TriggerTests(t, testRunnerOpt) +} + +func testK8sPolicyParser(p *policy.Parser) (*policy.Parser, error) { + p.TagVisitor = policy.K8sTestTagHandler() + return p, nil +} diff --git a/test/cel_go_test.bzl b/test/cel_go_test.bzl new file mode 100644 index 000000000..35cad9d4a --- /dev/null +++ b/test/cel_go_test.bzl @@ -0,0 +1,150 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Starlark rule for triggering unit tests on CEL policies and expressions for go +runtime target. 
+""" + +load("@bazel_skylib//lib:paths.bzl", "paths") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +def cel_go_test( + name, + test_suite = "", + config = "", + test_src = "", + cel_expr = "", + is_raw_expr = False, + base_config = "", + test_data_path = "", + file_descriptor_set = "", + filegroup = "", + deps = [], + data = [], + **kwargs): + """trigger tests for a CEL checked expression + + This macro wraps an invocation of the go_test rule and is used to trigger tests for a CEL + checked expression. + + Args: + name: str name for the generated artifact + test_suite: str label of a file containing a cel.expr.conformance.test.TestSuite message in + Textproto format. Alternatively, the file can also contain a test.Suite object in YAML format. + config: str label of a file containing a google.api.expr.conformance.Environment message in + Textproto format. Alternatively, the file can also contain the environment configuration represented + as an env.Config object in YAML format. + base_config: str label of a file containing a google.api.expr.conformance.Environment message + in Textproto format. Alternatively, the file can also contain the environment configuration represented + as an env.Config object in YAML format. + test_src: source file containing the invocaton of the test runner. + cel_expr: str label of a file containing either a CEL policy or a CEL expression or a checked + expression or a raw CEL expression string. + is_raw_expr: boolean indicating if the cel_expr is a raw CEL expression string. + test_data_path: path of the directory containing the test files. This is needed only + if the test files are not located in the same directory as the BUILD file. + file_descriptor_set: str label or filename pointing to a file_descriptor_set message. Note: + this must be in binary format with either a .binarypb or .pb or.fds extension. If you need + to support a textformat file_descriptor_set, embed it in the environment file. 
(default None) + filegroup: str label of a filegroup containing the test suite, config, and cel_expr. + deps: list of dependencies for the go_test rule + data: list of data dependencies for the go_test rule + **kwargs: additional arguments to pass to the go_test rule + """ + + _, cel_expr_format = paths.split_extension(cel_expr) + if filegroup != "": + data = data + [filegroup] + elif test_data_path != "" and test_data_path != native.package_name(): + if config != "": + data = data + [test_data_path + ":" + config] + if base_config != "": + data = data + [test_data_path + ":" + base_config] + if test_suite != "": + data = data + [test_data_path + ":" + test_suite] + if cel_expr_format == ".cel" or cel_expr_format == ".celpolicy" or cel_expr_format == ".yaml": + data = data + [test_data_path + ":" + cel_expr] + else: + test_data_path = native.package_name() + if config != "": + data = data + [config] + if base_config != "": + data = data + [base_config] + if test_suite != "": + data = data + [test_suite] + if cel_expr_format == ".cel" or cel_expr_format == ".celpolicy" or cel_expr_format == ".yaml": + data = data + [cel_expr] + + test_data_path = test_data_path.lstrip("/") + + if test_suite != "": + test_suite = test_data_path + "/" + test_suite + + if config != "": + config = test_data_path + "/" + config + + if base_config != "": + base_config = test_data_path + "/" + base_config + + srcs = [test_src] + + args = [ + "--test_suite_path=%s" % test_suite, + "--config_path=%s" % config, + "--base_config_path=%s" % base_config, + ] + + if cel_expr_format == ".cel" or cel_expr_format == ".celpolicy" or cel_expr_format == ".yaml": + args.append("--cel_expr=%s" % test_data_path + "/" + cel_expr) + elif is_raw_expr: + data = data + [cel_expr] + args.append("--cel_expr=%s" % cel_expr) + else: + args.append("--cel_expr=$(location {})".format(cel_expr)) + + if file_descriptor_set != "": + data = data + [file_descriptor_set] + args.append("--file_descriptor_set=$(location 
{})".format(file_descriptor_set)) + + go_test( + name = name, + srcs = srcs, + args = args, + data = data, + deps = deps + [ + "//cel:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//interpreter:go_default_library", + "//test:go_default_library", + "//tools/celtest:go_default_library", + "//tools/compiler:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@dev_cel_expr//:expr", + "@dev_cel_expr//conformance/test:go_default_library", + "@in_gopkg_yaml_v3//:go_default_library", + "@io_bazel_rules_go//go/runfiles", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//reflect/protodesc:go_default_library", + "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", + "@org_golang_google_protobuf//reflect/protoregistry:go_default_library", + "@org_golang_google_protobuf//testing/protocmp:go_default_library", + "@org_golang_google_protobuf//types/descriptorpb:go_default_library", + "@org_golang_google_protobuf//types/dynamicpb:go_default_library", + ], + **kwargs + ) diff --git a/tools/celtest/BUILD.bazel b/tools/celtest/BUILD.bazel index 295f4e755..d6fe8d651 100644 --- a/tools/celtest/BUILD.bazel +++ b/tools/celtest/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. 
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//test:cel_go_test.bzl", "cel_go_test") package( default_visibility = ["//visibility:public"], @@ -74,4 +75,15 @@ go_test( filegroup( name = "testdata", srcs = glob(["testdata/**"]), +) + +# trigger cel_go_test with resource files located under test_data_path +cel_go_test( + name = "pb_policy", + cel_expr = "testdata/pb/policy.yaml", + config = "testdata/pb/config.yaml", + file_descriptor_set = "//policy:test_all_types_fds", + test_data_path = "//policy", + test_src = "//policy/test:cel_test_runner.go", + test_suite = "testdata/pb/tests.yaml", ) \ No newline at end of file diff --git a/tools/celtest/test_runner.go b/tools/celtest/test_runner.go index 4e0cb3b8d..db5f765b4 100644 --- a/tools/celtest/test_runner.go +++ b/tools/celtest/test_runner.go @@ -19,6 +19,7 @@ import ( "flag" "fmt" "os" + "path/filepath" "reflect" "strings" "testing" @@ -61,6 +62,61 @@ func init() { flag.StringVar(&celExpression, "cel_expr", "", "CEL expression to test") } +func updateRunfilesPathForFlags(testResourcesDir string) error { + if testResourcesDir == "" { + return nil + } + paths := make([]*string, 0, 5) + if compiler.InferFileFormat(testSuitePath) != compiler.Unspecified { + paths = append(paths, &testSuitePath) + } + if compiler.InferFileFormat(fileDescriptorSetPath) != compiler.Unspecified { + paths = append(paths, &fileDescriptorSetPath) + } + if compiler.InferFileFormat(configPath) != compiler.Unspecified { + paths = append(paths, &configPath) + } + if compiler.InferFileFormat(baseConfigPath) != compiler.Unspecified { + paths = append(paths, &baseConfigPath) + } + if compiler.InferFileFormat(celExpression) != compiler.Unspecified { + paths = append(paths, &celExpression) + } + return UpdateTestResourcesPaths(testResourcesDir, paths) +} + +// UpdateTestResourcesPaths updates the list of paths with their absolute paths as per their location +// in the testResourcesDir directory. 
This will allow the executable targets to locate and access the +// data dependencies needed to trigger the tests. +// For example: In case of Bazel, this method can be used to update the file paths with the corresponding +// location in the runfiles directory tree: +// +// UpdateTestResourcesPaths(os.Getenv("RUNFILES_DIR"), ) +func UpdateTestResourcesPaths(testResourcesDir string, paths []*string) error { + if testResourcesDir == "" { + return nil + } + err := filepath.Walk(testResourcesDir, func(p string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + relPath, err := filepath.Rel(testResourcesDir, p) + if err != nil { + return err + } + for _, path := range paths { + if strings.Contains(relPath, *path) { + *path = p + } + } + return nil + }) + return err +} + // TestRunnerOption is used to configure the following attributes of the Test Runner: // - set the Compiler // - add Input Expressions @@ -111,10 +167,13 @@ func TriggerTests(t *testing.T, testRunnerOpts ...TestRunnerOption) { // File Descriptor Set Path of the test runner. // - Test expression - The `cel_expr` flag is used to populate the test expressions which need to be // evaluated by the test runner. -func TestRunnerOptionsFromFlags(testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) TestRunnerOption { +func TestRunnerOptionsFromFlags(testResourcesDir string, testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) TestRunnerOption { if !flag.Parsed() { flag.Parse() } + if err := updateRunfilesPathForFlags(testResourcesDir); err != nil { + return nil + } return func(tr *TestRunner) (*TestRunner, error) { opts := []TestRunnerOption{ testRunnerCompilerFromFlags(testCompilerOpts...), @@ -358,9 +417,18 @@ func fileDescriptorSet(path string) (*descpb.FileDescriptorSet, error) { return fds, nil } +// Program represents the result of creating CEL programs for the configured expressions in the +// test runner. 
It encompasses the following: +// - CELProgram - the evaluable CEL program +// - PolicyMetadata - the metadata map obtained while creating the CEL AST from the expression +type Program struct { + cel.Program + PolicyMetadata map[string]any +} + // Programs creates a list of CEL programs from the input expressions configured in the test runner // using the provided program options. -func (tr *TestRunner) Programs(t *testing.T, opts ...cel.ProgramOption) ([]cel.Program, error) { +func (tr *TestRunner) Programs(t *testing.T, opts ...cel.ProgramOption) ([]Program, error) { t.Helper() if tr.Compiler == nil { return nil, fmt.Errorf("compiler is not set") @@ -369,10 +437,9 @@ func (tr *TestRunner) Programs(t *testing.T, opts ...cel.ProgramOption) ([]cel.P if err != nil { return nil, err } - var programs []cel.Program + programs := make([]Program, 0, len(tr.Expressions)) for _, expr := range tr.Expressions { - // TODO: propagate metadata map along with the program instance as a struct. - ast, _, err := expr.CreateAST(tr.Compiler) + ast, policyMetadata, err := expr.CreateAST(tr.Compiler) if err != nil { if strings.Contains(err.Error(), "invalid file extension") || strings.Contains(err.Error(), "invalid raw expression") { @@ -384,7 +451,10 @@ func (tr *TestRunner) Programs(t *testing.T, opts ...cel.ProgramOption) ([]cel.P if err != nil { return nil, err } - programs = append(programs, prg) + programs = append(programs, Program{ + Program: prg, + PolicyMetadata: policyMetadata, + }) } return programs, nil } @@ -730,13 +800,16 @@ func (tr *TestRunner) createResultMatcher(t *testing.T, testOutput *test.Output) // ExecuteTest executes the test case against the provided list of programs and returns an error if // the test fails. 
-func (tr *TestRunner) ExecuteTest(t *testing.T, programs []cel.Program, test *Test) error { +func (tr *TestRunner) ExecuteTest(t *testing.T, programs []Program, test *Test) error { t.Helper() if tr.Compiler == nil { return fmt.Errorf("compiler is not set") } - for _, program := range programs { - out, _, err := program.Eval(test.input) + for _, pr := range programs { + if pr.Program == nil { + return fmt.Errorf("CEL program not set") + } + out, _, err := pr.Eval(test.input) if testResult := test.resultMatcher(out, err); !testResult.Success { return fmt.Errorf("test: %s \n wanted: %v \n failed: %v", test.name, testResult.Wanted, testResult.Error) } diff --git a/tools/celtest/test_runner_test.go b/tools/celtest/test_runner_test.go index c434d6b26..29c9e07d9 100644 --- a/tools/celtest/test_runner_test.go +++ b/tools/celtest/test_runner_test.go @@ -187,10 +187,9 @@ func fnEnvOption() cel.EnvOption { // TestTriggerTests tests different scenarios of the TriggerTestsFromCompiler function. func TestTriggerTests(t *testing.T) { for _, tc := range setupTests() { - tc := tc t.Run(tc.name, func(t *testing.T) { var testOpts []TestRunnerOption - var compileOpts []any = make([]any, 0, len(tc.opts)+2) + compileOpts := make([]any, 0, len(tc.opts)+2) for _, opt := range tc.opts { compileOpts = append(compileOpts, opt) } diff --git a/tools/compiler/compiler.go b/tools/compiler/compiler.go index 272df2926..b8fd70ea3 100644 --- a/tools/compiler/compiler.go +++ b/tools/compiler/compiler.go @@ -200,7 +200,7 @@ func InferFileFormat(path string) FileFormat { return TextProto case ".yaml": return TextYAML - case ".binarypb", ".fds": + case ".binarypb", ".fds", ".pb": return BinaryProto case ".cel": return CELString diff --git a/tools/compiler/compiler_test.go b/tools/compiler/compiler_test.go index 4bbad785e..5513919e3 100644 --- a/tools/compiler/compiler_test.go +++ b/tools/compiler/compiler_test.go @@ -72,7 +72,7 @@ func TestEnvironmentFileCompareTextprotoAndYAML(t *testing.T) { for i, v 
:= range protoConfig.Variables { for j, p := range v.TypeDesc.Params { if p.TypeName == "google.protobuf.Any" && - config.Variables[i].TypeDesc.Params[j].TypeName == "dyn" { + config.Variables[i].TypeDesc.Params[j].TypeName == "dyn" { p.TypeName = "dyn" } } diff --git a/tools/go.mod b/tools/go.mod index ea3a3dbf8..34459ca73 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -4,6 +4,7 @@ go 1.23.0 require ( cel.dev/expr v0.23.1 + github.com/bazelbuild/rules_go v0.54.0 github.com/google/cel-go v0.22.0 github.com/google/cel-go/policy v0.0.0-20250311174852-f5ea07b389a1 github.com/google/go-cmp v0.6.0 diff --git a/tools/go.sum b/tools/go.sum index 8e2a54ac3..482b98c2a 100644 --- a/tools/go.sum +++ b/tools/go.sum @@ -2,6 +2,8 @@ cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/bazelbuild/rules_go v0.54.0 h1:D9aCU7j5rdRxg2rXOZX5zHZ395XC0KbgC4rnyaQ3ofM= +github.com/bazelbuild/rules_go v0.54.0/go.mod h1:T90Gpyq4HDFlsrvtQa2CBdHNJ2P4rAu/uUTmQbanzf0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From 83ab6417de80a9bc63f94eb167f2e5b9d6082c08 Mon Sep 17 00:00:00 2001 From: ChinmayMadeshi Date: Fri, 23 May 2025 22:34:50 +0000 Subject: [PATCH 45/46] Create an util method to convert rpc status to eval status (#1178) * Create an util method to convert rpc status to eval status * Formatted the Go files. * Formatted the WORKSPACE file. * Add failing conformance test to SKIP_TESTS. * Added RefValue to ExprValue converter. * Update Module bazel. * Updated go.mod and go.sum. 
* Removed gRpc status dependencies from io and io_test. * formatted the files. * Followed the convention of returning the v1alpha1 format. * Updated the documentation. * Remove unnecessary dependencies. --- .gitignore | 1 + WORKSPACE | 9 +- cel/BUILD.bazel | 6 +- cel/io.go | 49 +++++ cel/io_test.go | 53 +++++ conformance/BUILD.bazel | 3 +- conformance/go.mod | 2 +- conformance/go.sum | 4 +- go.mod | 2 +- go.sum | 4 +- vendor/cel.dev/expr/eval.pb.go | 361 ++++++++++++++++----------------- vendor/modules.txt | 2 +- 12 files changed, 300 insertions(+), 196 deletions(-) diff --git a/.gitignore b/.gitignore index e4d06d5f0..2b38b6db6 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ proto/checked.pb.go proto/syntax.pb.go *~ MODULE.bazel.lock +.ijwb/ \ No newline at end of file diff --git a/WORKSPACE b/WORKSPACE index 956f16a4d..08a87a05b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -45,10 +45,10 @@ http_archive( urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.21.5.zip"], ) -load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") +load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") switched_rules_by_language( name = "com_google_googleapis_imports", @@ -101,8 +101,8 @@ go_repository( go_repository( name = "dev_cel_expr", importpath = "cel.dev/expr", - sum = "h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg=", - version = "v0.23.1", + sum = "h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY=", + version = "v0.24.0", ) # local_repository( @@ -158,12 +158,15 @@ go_register_toolchains(version = "1.22.0") gazelle_dependencies() load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies") + rules_proto_dependencies() load("@rules_proto//proto:setup.bzl", 
"rules_proto_setup") + rules_proto_setup() load("@rules_proto//proto:toolchains.bzl", "rules_proto_toolchains") + rules_proto_toolchains() protobuf_deps() diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index 4a0425a8e..c12e4904d 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -11,8 +11,8 @@ go_library( "decls.go", "env.go", "folding.go", - "io.go", "inlining.go", + "io.go", "library.go", "macro.go", "optimizer.go", @@ -64,8 +64,8 @@ go_test( "decls_test.go", "env_test.go", "folding_test.go", - "io_test.go", "inlining_test.go", + "io_test.go", "optimizer_test.go", "prompt_test.go", "validator_test.go", @@ -90,8 +90,8 @@ go_test( "//test/proto2pb:go_default_library", "//test/proto3pb:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", - "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//encoding/prototext:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//types/known/structpb:go_default_library", "@org_golang_google_protobuf//types/known/wrapperspb:go_default_library", ], diff --git a/cel/io.go b/cel/io.go index 7b1a4bed2..2e611228d 100644 --- a/cel/io.go +++ b/cel/io.go @@ -126,6 +126,55 @@ func ValueAsAlphaProto(res ref.Val) (*exprpb.Value, error) { return alpha, err } +// RefValToExprValue converts between ref.Val and google.api.expr.v1alpha1.ExprValue. +// The result ExprValue is the serialized proto form. +func RefValToExprValue(res ref.Val) (*exprpb.ExprValue, error) { + return ExprValueAsAlphaProto(res) +} + +// ExprValueAsAlphaProto converts between ref.Val and google.api.expr.v1alpha1.ExprValue. +// The result ExprValue is the serialized proto form. 
+func ExprValueAsAlphaProto(res ref.Val) (*exprpb.ExprValue, error) { + canonical, err := ExprValueAsProto(res) + if err != nil { + return nil, err + } + alpha := &exprpb.ExprValue{} + err = convertProto(canonical, alpha) + return alpha, err +} + +// ExprValueAsProto converts between ref.Val and cel.expr.ExprValue. +// The result ExprValue is the serialized proto form. +func ExprValueAsProto(res ref.Val) (*celpb.ExprValue, error) { + switch res := res.(type) { + case *types.Unknown: + return &celpb.ExprValue{ + Kind: &celpb.ExprValue_Unknown{ + Unknown: &celpb.UnknownSet{ + Exprs: res.IDs(), + }, + }}, nil + case *types.Err: + return &celpb.ExprValue{ + Kind: &celpb.ExprValue_Error{ + Error: &celpb.ErrorSet{ + // Keeping the error code as UNKNOWN since there's no error codes associated with + // Cel-Go runtime errors. + Errors: []*celpb.Status{{Code: 2, Message: res.Error()}}, + }, + }, + }, nil + default: + val, err := ValueAsProto(res) + if err != nil { + return nil, err + } + return &celpb.ExprValue{ + Kind: &celpb.ExprValue_Value{Value: val}}, nil + } +} + // ValueAsProto converts between ref.Val and cel.expr.Value. // The result Value is the serialized proto form. The ref.Val must not be error or unknown. 
func ValueAsProto(res ref.Val) (*celpb.Value, error) { diff --git a/cel/io_test.go b/cel/io_test.go index d2b864495..9da7ce43d 100644 --- a/cel/io_test.go +++ b/cel/io_test.go @@ -26,6 +26,7 @@ import ( celast "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" proto3pb "github.com/google/cel-go/test/proto3pb" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -193,6 +194,58 @@ func TestExprToString(t *testing.T) { } } +func TestRefValToExprValue(t *testing.T) { + tests := []struct { + name string + refVal ref.Val + expectError bool + }{ + { + name: "unknown value", + refVal: types.NewUnknown(1, nil), + expectError: false, + }, + { + name: "error value", + refVal: types.NewErr("test error"), + expectError: false, + }, + { + name: "bool value", + refVal: types.Bool(true), + expectError: false, + }, + { + name: "string value", + refVal: types.String("test"), + expectError: false, + }, + { + name: "int value", + refVal: types.Int(1), + expectError: false, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + exprVal, err := ExprValueAsProto(tc.refVal) + if tc.expectError { + if err == nil { + t.Errorf("RefValToExprValue(%v) expected error, got %v", tc.refVal, exprVal) + } + } else { + if err != nil { + t.Errorf("RefValToExprValue(%v) failed with error: %v", tc.refVal, err) + } + if exprVal == nil { + t.Errorf("RefValToExprValue(%v) expected value, got nil", tc.refVal) + } + } + }) + } +} + func TestAstToStringNil(t *testing.T) { expr, err := AstToString(nil) if err == nil || !strings.Contains(err.Error(), "unsupported expr") { diff --git a/conformance/BUILD.bazel b/conformance/BUILD.bazel index 50ec9c8d3..cdb8d6dfe 100644 --- a/conformance/BUILD.bazel +++ b/conformance/BUILD.bazel @@ -48,6 +48,7 @@ _TESTS_TO_SKIP = [ "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous", 
"macros/map/map_extract_keys", "timestamps/duration_converters/get_milliseconds", + "optionals/optionals/map_optional_select_has", # Temporarily failing tests, need a spec update "string_ext/value_errors/indexof_out_of_range,lastindexof_out_of_range", @@ -84,9 +85,9 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", "@dev_cel_expr//:expr", "@dev_cel_expr//conformance:go_default_library", - "@dev_cel_expr//conformance/test:go_default_library", "@dev_cel_expr//conformance/proto2:go_default_library", "@dev_cel_expr//conformance/proto3:go_default_library", + "@dev_cel_expr//conformance/test:go_default_library", "@io_bazel_rules_go//go/runfiles", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", "@org_golang_google_protobuf//encoding/prototext:go_default_library", diff --git a/conformance/go.mod b/conformance/go.mod index 92d2ca476..f03ad605d 100644 --- a/conformance/go.mod +++ b/conformance/go.mod @@ -3,7 +3,7 @@ module github.com/google/cel-go/conformance go 1.22.0 require ( - cel.dev/expr v0.23.1 + cel.dev/expr v0.24.0 github.com/bazelbuild/rules_go v0.49.0 github.com/google/cel-go v0.21.0 github.com/google/go-cmp v0.6.0 diff --git a/conformance/go.sum b/conformance/go.sum index 95ac73a52..2bd5da35d 100644 --- a/conformance/go.sum +++ b/conformance/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= -cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= +cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= +cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/bazelbuild/rules_go v0.49.0 h1:5vCbuvy8Q11g41lseGJDc5vxhDjJtfxr6nM/IC4VmqM= diff --git a/go.mod b/go.mod index 914c1ec28..e4ee536aa 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ 
go 1.22.0 toolchain go1.23.0 require ( - cel.dev/expr v0.23.1 + cel.dev/expr v0.24.0 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/stoewer/go-strcase v1.2.0 google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 diff --git a/go.sum b/go.sum index 23fe2170a..0ed868a96 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= -cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= +cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= +cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= diff --git a/vendor/cel.dev/expr/eval.pb.go b/vendor/cel.dev/expr/eval.pb.go index 8f651f9cc..a7aae0900 100644 --- a/vendor/cel.dev/expr/eval.pb.go +++ b/vendor/cel.dev/expr/eval.pb.go @@ -1,15 +1,15 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: -// protoc-gen-go v1.28.1 -// protoc v3.21.5 +// protoc-gen-go v1.36.3 +// protoc v5.27.1 // source: cel/expr/eval.proto package expr import ( - status "google.golang.org/genproto/googleapis/rpc/status" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + anypb "google.golang.org/protobuf/types/known/anypb" reflect "reflect" sync "sync" ) @@ -22,21 +22,18 @@ const ( ) type EvalState struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Values []*ExprValue `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` + Results []*EvalState_Result `protobuf:"bytes,3,rep,name=results,proto3" json:"results,omitempty"` unknownFields protoimpl.UnknownFields - - Values []*ExprValue `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` - Results []*EvalState_Result `protobuf:"bytes,3,rep,name=results,proto3" json:"results,omitempty"` + sizeCache protoimpl.SizeCache } func (x *EvalState) Reset() { *x = EvalState{} - if protoimpl.UnsafeEnabled { - mi := &file_cel_expr_eval_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_cel_expr_eval_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *EvalState) String() string { @@ -47,7 +44,7 @@ func (*EvalState) ProtoMessage() {} func (x *EvalState) ProtoReflect() protoreflect.Message { mi := &file_cel_expr_eval_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -77,25 +74,22 @@ func (x *EvalState) GetResults() []*EvalState_Result { } type ExprValue struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Types that are assignable to Kind: 
+ state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Kind: // // *ExprValue_Value // *ExprValue_Error // *ExprValue_Unknown - Kind isExprValue_Kind `protobuf_oneof:"kind"` + Kind isExprValue_Kind `protobuf_oneof:"kind"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExprValue) Reset() { *x = ExprValue{} - if protoimpl.UnsafeEnabled { - mi := &file_cel_expr_eval_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_cel_expr_eval_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ExprValue) String() string { @@ -106,7 +100,7 @@ func (*ExprValue) ProtoMessage() {} func (x *ExprValue) ProtoReflect() protoreflect.Message { mi := &file_cel_expr_eval_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -121,30 +115,36 @@ func (*ExprValue) Descriptor() ([]byte, []int) { return file_cel_expr_eval_proto_rawDescGZIP(), []int{1} } -func (m *ExprValue) GetKind() isExprValue_Kind { - if m != nil { - return m.Kind +func (x *ExprValue) GetKind() isExprValue_Kind { + if x != nil { + return x.Kind } return nil } func (x *ExprValue) GetValue() *Value { - if x, ok := x.GetKind().(*ExprValue_Value); ok { - return x.Value + if x != nil { + if x, ok := x.Kind.(*ExprValue_Value); ok { + return x.Value + } } return nil } func (x *ExprValue) GetError() *ErrorSet { - if x, ok := x.GetKind().(*ExprValue_Error); ok { - return x.Error + if x != nil { + if x, ok := x.Kind.(*ExprValue_Error); ok { + return x.Error + } } return nil } func (x *ExprValue) GetUnknown() *UnknownSet { - if x, ok := x.GetKind().(*ExprValue_Unknown); ok { - return x.Unknown + if x != nil { + if x, ok := x.Kind.(*ExprValue_Unknown); ok { + return x.Unknown + } } return nil } @@ 
-172,20 +172,17 @@ func (*ExprValue_Error) isExprValue_Kind() {} func (*ExprValue_Unknown) isExprValue_Kind() {} type ErrorSet struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Errors []*Status `protobuf:"bytes,1,rep,name=errors,proto3" json:"errors,omitempty"` unknownFields protoimpl.UnknownFields - - Errors []*status.Status `protobuf:"bytes,1,rep,name=errors,proto3" json:"errors,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ErrorSet) Reset() { *x = ErrorSet{} - if protoimpl.UnsafeEnabled { - mi := &file_cel_expr_eval_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_cel_expr_eval_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ErrorSet) String() string { @@ -196,7 +193,7 @@ func (*ErrorSet) ProtoMessage() {} func (x *ErrorSet) ProtoReflect() protoreflect.Message { mi := &file_cel_expr_eval_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -211,28 +208,85 @@ func (*ErrorSet) Descriptor() ([]byte, []int) { return file_cel_expr_eval_proto_rawDescGZIP(), []int{2} } -func (x *ErrorSet) GetErrors() []*status.Status { +func (x *ErrorSet) GetErrors() []*Status { if x != nil { return x.Errors } return nil } -type UnknownSet struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache +type Status struct { + state protoimpl.MessageState `protogen:"open.v1"` + Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + Details []*anypb.Any `protobuf:"bytes,3,rep,name=details,proto3" json:"details,omitempty"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} - Exprs []int64 
`protobuf:"varint,1,rep,packed,name=exprs,proto3" json:"exprs,omitempty"` +func (x *Status) Reset() { + *x = Status{} + mi := &file_cel_expr_eval_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } -func (x *UnknownSet) Reset() { - *x = UnknownSet{} - if protoimpl.UnsafeEnabled { - mi := &file_cel_expr_eval_proto_msgTypes[3] +func (x *Status) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Status) ProtoMessage() {} + +func (x *Status) ProtoReflect() protoreflect.Message { + mi := &file_cel_expr_eval_proto_msgTypes[3] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Status.ProtoReflect.Descriptor instead. +func (*Status) Descriptor() ([]byte, []int) { + return file_cel_expr_eval_proto_rawDescGZIP(), []int{3} +} + +func (x *Status) GetCode() int32 { + if x != nil { + return x.Code } + return 0 +} + +func (x *Status) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *Status) GetDetails() []*anypb.Any { + if x != nil { + return x.Details + } + return nil +} + +type UnknownSet struct { + state protoimpl.MessageState `protogen:"open.v1"` + Exprs []int64 `protobuf:"varint,1,rep,packed,name=exprs,proto3" json:"exprs,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UnknownSet) Reset() { + *x = UnknownSet{} + mi := &file_cel_expr_eval_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *UnknownSet) String() string { @@ -242,8 +296,8 @@ func (x *UnknownSet) String() string { func (*UnknownSet) ProtoMessage() {} func (x *UnknownSet) ProtoReflect() protoreflect.Message { - mi := &file_cel_expr_eval_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { + mi := 
&file_cel_expr_eval_proto_msgTypes[4] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -255,7 +309,7 @@ func (x *UnknownSet) ProtoReflect() protoreflect.Message { // Deprecated: Use UnknownSet.ProtoReflect.Descriptor instead. func (*UnknownSet) Descriptor() ([]byte, []int) { - return file_cel_expr_eval_proto_rawDescGZIP(), []int{3} + return file_cel_expr_eval_proto_rawDescGZIP(), []int{4} } func (x *UnknownSet) GetExprs() []int64 { @@ -266,21 +320,18 @@ func (x *UnknownSet) GetExprs() []int64 { } type EvalState_Result struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Expr int64 `protobuf:"varint,1,opt,name=expr,proto3" json:"expr,omitempty"` + Value int64 `protobuf:"varint,2,opt,name=value,proto3" json:"value,omitempty"` unknownFields protoimpl.UnknownFields - - Expr int64 `protobuf:"varint,1,opt,name=expr,proto3" json:"expr,omitempty"` - Value int64 `protobuf:"varint,2,opt,name=value,proto3" json:"value,omitempty"` + sizeCache protoimpl.SizeCache } func (x *EvalState_Result) Reset() { *x = EvalState_Result{} - if protoimpl.UnsafeEnabled { - mi := &file_cel_expr_eval_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_cel_expr_eval_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *EvalState_Result) String() string { @@ -290,8 +341,8 @@ func (x *EvalState_Result) String() string { func (*EvalState_Result) ProtoMessage() {} func (x *EvalState_Result) ProtoReflect() protoreflect.Message { - mi := &file_cel_expr_eval_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_cel_expr_eval_proto_msgTypes[5] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -325,39 +376,45 @@ var 
File_cel_expr_eval_proto protoreflect.FileDescriptor var file_cel_expr_eval_proto_rawDesc = []byte{ 0x0a, 0x13, 0x63, 0x65, 0x6c, 0x2f, 0x65, 0x78, 0x70, 0x72, 0x2f, 0x65, 0x76, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x63, 0x65, 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x1a, - 0x14, 0x63, 0x65, 0x6c, 0x2f, 0x65, 0x78, 0x70, 0x72, 0x2f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x72, 0x70, - 0x63, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa2, - 0x01, 0x0a, 0x09, 0x45, 0x76, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x2b, 0x0a, 0x06, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x63, - 0x65, 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x2e, 0x45, 0x78, 0x70, 0x72, 0x56, 0x61, 0x6c, 0x75, - 0x65, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x12, 0x34, 0x0a, 0x07, 0x72, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x63, 0x65, 0x6c, - 0x2e, 0x65, 0x78, 0x70, 0x72, 0x2e, 0x45, 0x76, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x2e, - 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x1a, - 0x32, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x65, 0x78, 0x70, - 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x65, 0x78, 0x70, 0x72, 0x12, 0x14, 0x0a, - 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x76, 0x61, - 0x6c, 0x75, 0x65, 0x22, 0x9a, 0x01, 0x0a, 0x09, 0x45, 0x78, 0x70, 0x72, 0x56, 0x61, 0x6c, 0x75, - 0x65, 0x12, 0x27, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x0f, 0x2e, 0x63, 0x65, 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x2e, 0x56, 0x61, 0x6c, 0x75, - 0x65, 0x48, 0x00, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x2a, 0x0a, 0x05, 0x65, 0x72, - 0x72, 0x6f, 0x72, 0x18, 0x02, 0x20, 
0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x63, 0x65, 0x6c, 0x2e, - 0x65, 0x78, 0x70, 0x72, 0x2e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x48, 0x00, 0x52, - 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x30, 0x0a, 0x07, 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, - 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x65, 0x6c, 0x2e, 0x65, 0x78, - 0x70, 0x72, 0x2e, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x53, 0x65, 0x74, 0x48, 0x00, 0x52, - 0x07, 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x42, 0x06, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, - 0x22, 0x36, 0x0a, 0x08, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x12, 0x2a, 0x0a, 0x06, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x67, - 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x52, 0x06, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x22, 0x22, 0x0a, 0x0a, 0x55, 0x6e, 0x6b, 0x6e, - 0x6f, 0x77, 0x6e, 0x53, 0x65, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x78, 0x70, 0x72, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x03, 0x52, 0x05, 0x65, 0x78, 0x70, 0x72, 0x73, 0x42, 0x2c, 0x0a, 0x0c, - 0x64, 0x65, 0x76, 0x2e, 0x63, 0x65, 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x42, 0x09, 0x45, 0x76, - 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x0c, 0x63, 0x65, 0x6c, 0x2e, 0x64, - 0x65, 0x76, 0x2f, 0x65, 0x78, 0x70, 0x72, 0xf8, 0x01, 0x01, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x19, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2f, 0x61, 0x6e, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x14, 0x63, 0x65, 0x6c, 0x2f, + 0x65, 0x78, 0x70, 0x72, 0x2f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x22, 0xa2, 0x01, 0x0a, 0x09, 0x45, 0x76, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x2b, + 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, + 0x2e, 0x63, 0x65, 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x2e, 
0x45, 0x78, 0x70, 0x72, 0x56, 0x61, + 0x6c, 0x75, 0x65, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x12, 0x34, 0x0a, 0x07, 0x72, + 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x63, + 0x65, 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x2e, 0x45, 0x76, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, + 0x73, 0x1a, 0x32, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x65, + 0x78, 0x70, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x65, 0x78, 0x70, 0x72, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x9a, 0x01, 0x0a, 0x09, 0x45, 0x78, 0x70, 0x72, 0x56, 0x61, + 0x6c, 0x75, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x63, 0x65, 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x2e, 0x56, 0x61, + 0x6c, 0x75, 0x65, 0x48, 0x00, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x2a, 0x0a, 0x05, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x63, 0x65, + 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x2e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x48, + 0x00, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x30, 0x0a, 0x07, 0x75, 0x6e, 0x6b, 0x6e, + 0x6f, 0x77, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x65, 0x6c, 0x2e, + 0x65, 0x78, 0x70, 0x72, 0x2e, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x53, 0x65, 0x74, 0x48, + 0x00, 0x52, 0x07, 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x42, 0x06, 0x0a, 0x04, 0x6b, 0x69, + 0x6e, 0x64, 0x22, 0x34, 0x0a, 0x08, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x53, 0x65, 0x74, 0x12, 0x28, + 0x0a, 0x06, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, + 0x2e, 0x63, 0x65, 0x6c, 0x2e, 0x65, 0x78, 0x70, 0x72, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 
0x52, 0x06, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x73, 0x22, 0x66, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, + 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x12, 0x2e, 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, + 0x22, 0x22, 0x0a, 0x0a, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x53, 0x65, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x65, 0x78, 0x70, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x03, 0x52, 0x05, 0x65, + 0x78, 0x70, 0x72, 0x73, 0x42, 0x2c, 0x0a, 0x0c, 0x64, 0x65, 0x76, 0x2e, 0x63, 0x65, 0x6c, 0x2e, + 0x65, 0x78, 0x70, 0x72, 0x42, 0x09, 0x45, 0x76, 0x61, 0x6c, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, + 0x01, 0x5a, 0x0c, 0x63, 0x65, 0x6c, 0x2e, 0x64, 0x65, 0x76, 0x2f, 0x65, 0x78, 0x70, 0x72, 0xf8, + 0x01, 0x01, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -372,28 +429,30 @@ func file_cel_expr_eval_proto_rawDescGZIP() []byte { return file_cel_expr_eval_proto_rawDescData } -var file_cel_expr_eval_proto_msgTypes = make([]protoimpl.MessageInfo, 5) -var file_cel_expr_eval_proto_goTypes = []interface{}{ +var file_cel_expr_eval_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_cel_expr_eval_proto_goTypes = []any{ (*EvalState)(nil), // 0: cel.expr.EvalState (*ExprValue)(nil), // 1: cel.expr.ExprValue (*ErrorSet)(nil), // 2: cel.expr.ErrorSet - (*UnknownSet)(nil), // 3: cel.expr.UnknownSet - (*EvalState_Result)(nil), // 4: cel.expr.EvalState.Result - (*Value)(nil), // 5: cel.expr.Value - (*status.Status)(nil), // 6: google.rpc.Status + (*Status)(nil), // 3: cel.expr.Status + 
(*UnknownSet)(nil), // 4: cel.expr.UnknownSet + (*EvalState_Result)(nil), // 5: cel.expr.EvalState.Result + (*Value)(nil), // 6: cel.expr.Value + (*anypb.Any)(nil), // 7: google.protobuf.Any } var file_cel_expr_eval_proto_depIdxs = []int32{ 1, // 0: cel.expr.EvalState.values:type_name -> cel.expr.ExprValue - 4, // 1: cel.expr.EvalState.results:type_name -> cel.expr.EvalState.Result - 5, // 2: cel.expr.ExprValue.value:type_name -> cel.expr.Value + 5, // 1: cel.expr.EvalState.results:type_name -> cel.expr.EvalState.Result + 6, // 2: cel.expr.ExprValue.value:type_name -> cel.expr.Value 2, // 3: cel.expr.ExprValue.error:type_name -> cel.expr.ErrorSet - 3, // 4: cel.expr.ExprValue.unknown:type_name -> cel.expr.UnknownSet - 6, // 5: cel.expr.ErrorSet.errors:type_name -> google.rpc.Status - 6, // [6:6] is the sub-list for method output_type - 6, // [6:6] is the sub-list for method input_type - 6, // [6:6] is the sub-list for extension type_name - 6, // [6:6] is the sub-list for extension extendee - 0, // [0:6] is the sub-list for field type_name + 4, // 4: cel.expr.ExprValue.unknown:type_name -> cel.expr.UnknownSet + 3, // 5: cel.expr.ErrorSet.errors:type_name -> cel.expr.Status + 7, // 6: cel.expr.Status.details:type_name -> google.protobuf.Any + 7, // [7:7] is the sub-list for method output_type + 7, // [7:7] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name } func init() { file_cel_expr_eval_proto_init() } @@ -402,69 +461,7 @@ func file_cel_expr_eval_proto_init() { return } file_cel_expr_value_proto_init() - if !protoimpl.UnsafeEnabled { - file_cel_expr_eval_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*EvalState); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - 
file_cel_expr_eval_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExprValue); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_cel_expr_eval_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ErrorSet); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_cel_expr_eval_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*UnknownSet); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_cel_expr_eval_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*EvalState_Result); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - file_cel_expr_eval_proto_msgTypes[1].OneofWrappers = []interface{}{ + file_cel_expr_eval_proto_msgTypes[1].OneofWrappers = []any{ (*ExprValue_Value)(nil), (*ExprValue_Error)(nil), (*ExprValue_Unknown)(nil), @@ -475,7 +472,7 @@ func file_cel_expr_eval_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_cel_expr_eval_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 6, NumExtensions: 0, NumServices: 0, }, diff --git a/vendor/modules.txt b/vendor/modules.txt index a34dce8d0..1615630be 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# cel.dev/expr v0.23.1 +# cel.dev/expr v0.24.0 ## explicit; go 1.22.0 cel.dev/expr # github.com/antlr4-go/antlr/v4 v4.13.0 From 746d711377ddc7e49686f6cce1346b4ed6659d11 Mon Sep 17 00:00:00 2001 From: l46kok Date: Tue, 27 May 2025 18:18:15 -0700 Subject: [PATCH 46/46] Fix container setting for cel test all types example in online REPL (#1182) --- 
.../web/src/app/reference_panel/reference-panel-component.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/repl/appengine/web/src/app/reference_panel/reference-panel-component.ts b/repl/appengine/web/src/app/reference_panel/reference-panel-component.ts index 1c68155e1..5257d5e36 100644 --- a/repl/appengine/web/src/app/reference_panel/reference-panel-component.ts +++ b/repl/appengine/web/src/app/reference_panel/reference-panel-component.ts @@ -188,7 +188,7 @@ const examples = new Map([ request: { commands: [ `%load_descriptors --pkg 'cel-spec-test-types'`, - `%option --container "google.api.expr.test.v1"`, + `%option --container "cel.expr.conformance"`, `%let pb3 = proto3.TestAllTypes{}`, `%let pb2 = proto2.TestAllTypes`, `pb3 == proto3.TestAllTypes{}`