entgql: support skipping tx-opening by operation or field name by a8m · Pull Request #571 · ent/contrib · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

entgql: support skipping tx-opening by operation or field name #571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions entgql/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"database/sql/driver"
"errors"
"slices"
"sync"

"github.com/99designs/gqlgen/graphql"
Expand All @@ -39,8 +40,35 @@ func (f TxOpenerFunc) OpenTx(ctx context.Context) (context.Context, driver.Tx, e
return f(ctx)
}

// Transactioner for graphql mutations: a handler extension that runs
// each mutation under a transaction opened by the embedded TxOpener.
type Transactioner struct{ TxOpener }
type (
// Transactioner for graphql mutations.
Transactioner struct {
TxOpener
SkipTxFunc
}
// SkipTxFunc allows skipping operations from
// running under a transaction.
SkipTxFunc func(*ast.OperationDefinition) bool
)

// SkipOperations returns a SkipTxFunc that skips opening a transaction for
// operations whose name is one of the given names. It returns false for a
// nil operation.
//
// The names are copied, so modifying the caller's slice after this call
// (e.g. when invoked as SkipOperations(list...)) does not change the set of
// skipped operations.
func SkipOperations(names ...string) SkipTxFunc {
	// Detach from the caller's backing array; a variadic call with `s...`
	// otherwise aliases the caller's slice.
	names = slices.Clone(names)
	return func(op *ast.OperationDefinition) bool {
		return op != nil && slices.Contains(names, op.Name)
	}
}

// SkipIfHasFields returns a SkipTxFunc that skips opening a transaction when
// the operation selects, at its top level, a field whose name (not alias) is
// one of the given names. Top-level fields selected through inline fragments
// or fragment spreads are taken into account as well. It returns false for a
// nil operation.
//
// The names are copied, so modifying the caller's slice after this call does
// not change the matched fields.
func SkipIfHasFields(names ...string) SkipTxFunc {
	names = slices.Clone(names)
	return func(op *ast.OperationDefinition) bool {
		if op == nil {
			return false
		}
		return selectionHasField(op.SelectionSet, names, make(map[string]bool))
	}
}

// selectionHasField reports whether set selects a field named in names,
// either directly or through inline fragments / fragment spreads. seen records
// visited fragment definitions so a (invalid) fragment cycle cannot recurse
// forever.
func selectionHasField(set ast.SelectionSet, names []string, seen map[string]bool) bool {
	return slices.ContainsFunc(set, func(sel ast.Selection) bool {
		switch s := sel.(type) {
		case *ast.Field:
			return slices.Contains(names, s.Name)
		case *ast.InlineFragment:
			return selectionHasField(s.SelectionSet, names, seen)
		case *ast.FragmentSpread:
			// Definition is resolved during validation; an unresolved or
			// already-visited spread contributes no fields.
			if s.Definition == nil || seen[s.Name] {
				return false
			}
			seen[s.Name] = true
			return selectionHasField(s.Definition.SelectionSet, names, seen)
		}
		return false
	})
}

var _ interface {
graphql.HandlerExtension
Expand All @@ -62,8 +90,8 @@ func (t Transactioner) Validate(graphql.ExecutableSchema) error {
}

// MutateOperationContext serializes field resolvers during mutations.
func (Transactioner) MutateOperationContext(_ context.Context, oc *graphql.OperationContext) *gqlerror.Error {
if op := oc.Operation; op != nil && op.Operation == ast.Mutation {
func (t Transactioner) MutateOperationContext(_ context.Context, oc *graphql.OperationContext) *gqlerror.Error {
if !t.skipTx(oc.Operation) {
previous := oc.ResolverMiddleware
var mu sync.Mutex
oc.ResolverMiddleware = func(ctx context.Context, next graphql.Resolver) (interface{}, error) {
Expand All @@ -77,7 +105,7 @@ func (Transactioner) MutateOperationContext(_ context.Context, oc *graphql.Opera

// InterceptResponse runs graphql mutations under a transaction.
func (t Transactioner) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
if op := graphql.GetOperationContext(ctx).Operation; op == nil || op.Operation != ast.Mutation {
if t.skipTx(graphql.GetOperationContext(ctx).Operation) {
return next(ctx)
}
txCtx, tx, err := t.OpenTx(ctx)
Expand Down Expand Up @@ -108,3 +136,7 @@ func (t Transactioner) InterceptResponse(ctx context.Context, next graphql.Respo
}
return rsp
}

// skipTx reports whether op must run without a transaction. A missing
// operation and any operation other than a mutation are always skipped; a
// mutation is skipped only when SkipTxFunc is set and returns true for it.
func (t Transactioner) skipTx(op *ast.OperationDefinition) bool {
	switch {
	case op == nil:
		return true
	case op.Operation != ast.Mutation:
		return true
	case t.SkipTxFunc == nil:
		return false
	}
	return t.SkipTxFunc(op)
}
76 changes: 69 additions & 7 deletions entgql/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ import (
)

func TestTransaction(t *testing.T) {
newServer := func(opener entgql.TxOpener) *testserver.TestServer {
newServer := func(opener entgql.TxOpener, skip entgql.SkipTxFunc) *testserver.TestServer {
srv := testserver.New()
srv.AddTransport(transport.POST{})
srv.Use(entgql.Transactioner{TxOpener: opener})
srv.Use(entgql.Transactioner{TxOpener: opener, SkipTxFunc: skip})
return srv
}
fwdCtx := func(ctx context.Context) context.Context {
Expand All @@ -44,7 +44,7 @@ func TestTransaction(t *testing.T) {
t.Parallel()
var opener mocks.TxOpener
defer opener.AssertExpectations(t)
srv := newServer(&opener)
srv := newServer(&opener, nil)

c := client.New(srv)
err := c.Post(`query { name }`, &struct{ Name string }{})
Expand All @@ -65,7 +65,7 @@ func TestTransaction(t *testing.T) {
Once()
defer opener.AssertExpectations(t)

srv := newServer(&opener)
srv := newServer(&opener, nil)
srv.AroundResponses(func(context.Context, graphql.ResponseHandler) *graphql.Response {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}
})
Expand All @@ -74,6 +74,68 @@ func TestTransaction(t *testing.T) {
err := c.Post(`mutation { name }`, &struct{ Name string }{})
require.NoError(t, err)
})

t.Run("SkipOperation", func(t *testing.T) {
// Two mutations against the same server: "skipped" must bypass the
// TxOpener, "notSkipped" must open one transaction that commits once.
var (
tx mocks.Tx
opener mocks.TxOpener
)
tx.On("Commit").
Return(nil).
Once()
defer tx.AssertExpectations(t)

// Operations named "skipped" are excluded from transactions.
srv := newServer(&opener, entgql.SkipOperations("skipped"))
srv.AroundResponses(func(context.Context, graphql.ResponseHandler) *graphql.Response {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}
})

c := client.New(srv)
// No OpenTx expectation is registered yet; presumably the generated mock
// fails on an unexpected OpenTx call, which is what proves the skip here.
// NOTE(review): confirm against the mocks.TxOpener implementation.
err := c.Post(`mutation skipped { name }`, &struct{ Name string }{})
require.NoError(t, err)
opener.AssertExpectations(t)

// A differently named mutation is not in the skip list and must open
// exactly one transaction.
opener.On("OpenTx", mock.Anything).
Return(fwdCtx, &tx, nil).
Once()
err = c.Post(`mutation notSkipped { name }`, &struct{ Name string }{})
require.NoError(t, err)
opener.AssertExpectations(t)
})

t.Run("SkipIfHasFields", func(t *testing.T) {
// The same mutation is sent to two servers: one configured to skip
// operations selecting "name" (no transaction expected), and one
// configured with an unrelated field "work" (one committed transaction).
var (
tx mocks.Tx
opener mocks.TxOpener
)
tx.On("Commit").
Return(nil).
Once()
defer tx.AssertExpectations(t)
defer opener.AssertExpectations(t)

// The mutation selects "name", which is in the skip list.
srv := newServer(&opener, entgql.SkipIfHasFields("name"))
srv.AroundResponses(func(context.Context, graphql.ResponseHandler) *graphql.Response {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}
})
c := client.New(srv)
// No OpenTx expectation is registered yet; presumably an unexpected
// OpenTx call would fail the mock. NOTE(review): confirm mock behavior.
err := c.Post(`mutation { name }`, &struct{ Name string }{})
require.NoError(t, err)
opener.AssertExpectations(t)

// A fresh server that only skips "work": the same mutation must now
// open exactly one transaction.
opener.On("OpenTx", mock.Anything).
Return(fwdCtx, &tx, nil).
Once()
srv = newServer(&opener, entgql.SkipIfHasFields("work"))
srv.AroundResponses(func(context.Context, graphql.ResponseHandler) *graphql.Response {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}
})
c = client.New(srv)
err = c.Post(`mutation { name }`, &struct{ Name string }{})
require.NoError(t, err)
opener.AssertExpectations(t)
})

t.Run("Err", func(t *testing.T) {
t.Parallel()
var tx mocks.Tx
Expand All @@ -88,7 +150,7 @@ func TestTransaction(t *testing.T) {
Once()
defer opener.AssertExpectations(t)

srv := newServer(&opener)
srv := newServer(&opener, nil)
srv.AroundResponses(func(ctx context.Context, _ graphql.ResponseHandler) *graphql.Response {
return graphql.ErrorResponse(ctx, "bad mutation")
})
Expand All @@ -112,7 +174,7 @@ func TestTransaction(t *testing.T) {
Once()
defer opener.AssertExpectations(t)

srv := newServer(&opener)
srv := newServer(&opener, nil)
srv.SetRecoverFunc(func(_ context.Context, err interface{}) error {
return err.(error)
})
Expand All @@ -133,7 +195,7 @@ func TestTransaction(t *testing.T) {
Once()
defer opener.AssertExpectations(t)

srv := newServer(&opener)
srv := newServer(&opener, nil)
c := client.New(srv)
err := c.Post(`mutation { name }`, &struct{ Name string }{})
require.Error(t, err)
Expand Down
Loading
0