From 483201a3722a26e593a2aa6c4da0474eb6342d02 Mon Sep 17 00:00:00 2001
From: Ariel Mashraki
Date: Tue, 9 Apr 2024 17:33:46 +0300
Subject: [PATCH] entgql: support skipping tx-opening by operation or field name

---
 entgql/transaction.go      | 42 ++++++++++++++++++---
 entgql/transaction_test.go | 76 ++++++++++++++++++++++++++++++++++----
 2 files changed, 106 insertions(+), 12 deletions(-)

diff --git a/entgql/transaction.go b/entgql/transaction.go
index d3f0d42d3..ff9a81fdb 100644
--- a/entgql/transaction.go
+++ b/entgql/transaction.go
@@ -18,6 +18,7 @@ import (
 	"context"
 	"database/sql/driver"
 	"errors"
+	"slices"
 	"sync"
 
 	"github.com/99designs/gqlgen/graphql"
@@ -39,8 +40,35 @@ func (f TxOpenerFunc) OpenTx(ctx context.Context) (context.Context, driver.Tx, e
 	return f(ctx)
 }
 
-// Transactioner for graphql mutations.
-type Transactioner struct{ TxOpener }
+type (
+	// Transactioner runs graphql mutations under a transaction.
+	Transactioner struct {
+		TxOpener
+		SkipTxFunc
+	}
+	// SkipTxFunc reports whether a mutation operation should
+	// not be executed under a transaction.
+	SkipTxFunc func(*ast.OperationDefinition) bool
+)
+
+// SkipOperations returns a SkipTxFunc that skips operations with
+// the given names from running under a transaction.
+func SkipOperations(names ...string) SkipTxFunc {
+	return func(op *ast.OperationDefinition) bool {
+		return slices.Contains(names, op.Name)
+	}
+}
+
+// SkipIfHasFields skips operations whose top-level selection set has
+// a field with one of the given names (fragments are not inspected).
+func SkipIfHasFields(names ...string) SkipTxFunc {
+	return func(op *ast.OperationDefinition) bool {
+		return slices.ContainsFunc(op.SelectionSet, func(s ast.Selection) bool {
+			f, ok := s.(*ast.Field)
+			return ok && slices.Contains(names, f.Name)
+		})
+	}
+}
 
 var _ interface {
 	graphql.HandlerExtension
@@ -62,8 +90,8 @@ func (t Transactioner) Validate(graphql.ExecutableSchema) error {
 }
 
 // MutateOperationContext serializes field resolvers during mutations.
-func (Transactioner) MutateOperationContext(_ context.Context, oc *graphql.OperationContext) *gqlerror.Error {
-	if op := oc.Operation; op != nil && op.Operation == ast.Mutation {
+func (t Transactioner) MutateOperationContext(_ context.Context, oc *graphql.OperationContext) *gqlerror.Error {
+	if !t.skipTx(oc.Operation) {
 		previous := oc.ResolverMiddleware
 		var mu sync.Mutex
 		oc.ResolverMiddleware = func(ctx context.Context, next graphql.Resolver) (interface{}, error) {
@@ -77,7 +105,7 @@ func (Transactioner) MutateOperationContext(_ context.Context, oc *graphql.Opera
 
 // InterceptResponse runs graphql mutations under a transaction.
 func (t Transactioner) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
-	if op := graphql.GetOperationContext(ctx).Operation; op == nil || op.Operation != ast.Mutation {
+	if t.skipTx(graphql.GetOperationContext(ctx).Operation) {
 		return next(ctx)
 	}
 	txCtx, tx, err := t.OpenTx(ctx)
@@ -108,3 +136,7 @@ func (t Transactioner) InterceptResponse(ctx context.Context, next graphql.Respo
 	}
 	return rsp
 }
+
+func (t Transactioner) skipTx(op *ast.OperationDefinition) bool {
+	return op == nil || op.Operation != ast.Mutation || (t.SkipTxFunc != nil && t.SkipTxFunc(op))
+}
diff --git a/entgql/transaction_test.go b/entgql/transaction_test.go
index ea75e04e8..9bbacfabb 100644
--- a/entgql/transaction_test.go
+++ b/entgql/transaction_test.go
@@ -30,10 +30,10 @@ import (
 )
 
 func TestTransaction(t *testing.T) {
-	newServer := func(opener entgql.TxOpener) *testserver.TestServer {
+	newServer := func(opener entgql.TxOpener, skip entgql.SkipTxFunc) *testserver.TestServer {
 		srv := testserver.New()
 		srv.AddTransport(transport.POST{})
-		srv.Use(entgql.Transactioner{TxOpener: opener})
+		srv.Use(entgql.Transactioner{TxOpener: opener, SkipTxFunc: skip})
 		return srv
 	}
 	fwdCtx := func(ctx context.Context) context.Context {
@@ -44,7 +44,7 @@ func TestTransaction(t *testing.T) {
 		t.Parallel()
 		var opener mocks.TxOpener
 		defer opener.AssertExpectations(t)
-		srv := newServer(&opener)
+		srv := newServer(&opener, nil)
 		c := client.New(srv)
 
 		err := c.Post(`query { name }`, &struct{ Name string }{})
@@ -65,7 +65,7 @@ func TestTransaction(t *testing.T) {
 			Once()
 		defer opener.AssertExpectations(t)
 
-		srv := newServer(&opener)
+		srv := newServer(&opener, nil)
 		srv.AroundResponses(func(context.Context, graphql.ResponseHandler) *graphql.Response {
 			return &graphql.Response{Data: []byte(`{"name":"test"}`)}
 		})
@@ -74,6 +74,68 @@ func TestTransaction(t *testing.T) {
 		err := c.Post(`mutation { name }`, &struct{ Name string }{})
 		require.NoError(t, err)
 	})
+
+	t.Run("SkipOperation", func(t *testing.T) {
+		var (
+			tx     mocks.Tx
+			opener mocks.TxOpener
+		)
+		tx.On("Commit").
+			Return(nil).
+			Once()
+		defer tx.AssertExpectations(t)
+
+		srv := newServer(&opener, entgql.SkipOperations("skipped"))
+		srv.AroundResponses(func(context.Context, graphql.ResponseHandler) *graphql.Response {
+			return &graphql.Response{Data: []byte(`{"name":"test"}`)}
+		})
+
+		c := client.New(srv)
+		err := c.Post(`mutation skipped { name }`, &struct{ Name string }{})
+		require.NoError(t, err)
+		opener.AssertExpectations(t)
+
+		opener.On("OpenTx", mock.Anything).
+			Return(fwdCtx, &tx, nil).
+			Once()
+		err = c.Post(`mutation notSkipped { name }`, &struct{ Name string }{})
+		require.NoError(t, err)
+		opener.AssertExpectations(t)
+	})
+
+	t.Run("SkipIfHasFields", func(t *testing.T) {
+		var (
+			tx     mocks.Tx
+			opener mocks.TxOpener
+		)
+		tx.On("Commit").
+			Return(nil).
+			Once()
+		defer tx.AssertExpectations(t)
+		defer opener.AssertExpectations(t)
+
+		srv := newServer(&opener, entgql.SkipIfHasFields("name"))
+		srv.AroundResponses(func(context.Context, graphql.ResponseHandler) *graphql.Response {
+			return &graphql.Response{Data: []byte(`{"name":"test"}`)}
+		})
+		c := client.New(srv)
+		err := c.Post(`mutation { name }`, &struct{ Name string }{})
+		require.NoError(t, err)
+		opener.AssertExpectations(t)
+
+		opener.On("OpenTx", mock.Anything).
+			Return(fwdCtx, &tx, nil).
+			Once()
+		srv = newServer(&opener, entgql.SkipIfHasFields("work"))
+		srv.AroundResponses(func(context.Context, graphql.ResponseHandler) *graphql.Response {
+			return &graphql.Response{Data: []byte(`{"name":"test"}`)}
+		})
+		c = client.New(srv)
+		err = c.Post(`mutation { name }`, &struct{ Name string }{})
+		require.NoError(t, err)
+		opener.AssertExpectations(t)
+	})
+
 	t.Run("Err", func(t *testing.T) {
 		t.Parallel()
 		var tx mocks.Tx
@@ -88,7 +150,7 @@ func TestTransaction(t *testing.T) {
 			Once()
 		defer opener.AssertExpectations(t)
 
-		srv := newServer(&opener)
+		srv := newServer(&opener, nil)
 		srv.AroundResponses(func(ctx context.Context, _ graphql.ResponseHandler) *graphql.Response {
 			return graphql.ErrorResponse(ctx, "bad mutation")
 		})
@@ -112,7 +174,7 @@ func TestTransaction(t *testing.T) {
 			Once()
 		defer opener.AssertExpectations(t)
 
-		srv := newServer(&opener)
+		srv := newServer(&opener, nil)
 		srv.SetRecoverFunc(func(_ context.Context, err interface{}) error {
 			return err.(error)
 		})
@@ -133,7 +195,7 @@ func TestTransaction(t *testing.T) {
 			Once()
 		defer opener.AssertExpectations(t)
 
-		srv := newServer(&opener)
+		srv := newServer(&opener, nil)
 		c := client.New(srv)
 		err := c.Post(`mutation { name }`, &struct{ Name string }{})
 		require.Error(t, err)