From 527992e6691e50e495167689ce068ec5e203002e Mon Sep 17 00:00:00 2001 From: chenquan Date: Sun, 11 Sep 2022 18:12:58 +0800 Subject: [PATCH] feat: support multiple hooks --- conn.go | 8 +- connector.go | 2 +- hook.go | 207 ++++++++++++++++++++++++++++++++++++++++++++------- stmt.go | 23 +++--- tx.go | 4 +- 5 files changed, 196 insertions(+), 48 deletions(-) diff --git a/conn.go b/conn.go index 92299b0..5fc3336 100644 --- a/conn.go +++ b/conn.go @@ -19,7 +19,7 @@ type conn struct { } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { - ctx, err = c.BeforeExecContext(ctx, query, args) + ctx, query, args, err = c.BeforeExecContext(ctx, query, args, nil) defer func() { _, result, err = c.AfterExecContext(ctx, query, args, result, err) }() @@ -48,7 +48,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { - ctx, err = c.BeforeQueryContext(ctx, query, args) + ctx, query, args, err = c.BeforeQueryContext(ctx, query, args, nil) defer func() { _, rows, err = c.AfterQueryContext(ctx, query, args, rows, err) }() @@ -78,7 +78,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) { - ctx, err = c.BeforePrepareContext(ctx, query) + ctx, query, err = c.BeforePrepareContext(ctx, query, nil) defer func() { _, s, err = c.AfterPrepareContext(ctx, query, s, err) }() @@ -101,7 +101,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, } func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (dd driver.Tx, err error) { - ctx, err = c.BeforeBeginTx(ctx, opts) + ctx, opts, err = c.BeforeBeginTx(ctx, opts, nil) defer func() { _, dd, err = c.AfterBeginTx(ctx, opts, dd, err) }() diff --git a/connector.go b/connector.go index e74ee13..172df63 100644 --- a/connector.go +++ b/connector.go @@ -13,7 +13,7 @@ type connector struct { } func (c *connector) Connect(ctx context.Context) (dc driver.Conn, err error) { - ctx, err = c.BeforeConnect(ctx) + ctx, err = c.BeforeConnect(ctx, nil) defer func() { _, dc, err = c.AfterConnect(ctx, dc, err) }() diff --git a/hook.go b/hook.go index b0122a4..bc530b6 100644 --- a/hook.go +++ b/hook.go @@ -5,39 +5,192 @@ import ( "database/sql/driver" ) -type Hook interface { - ConnectorHook - ConnHook - TxHook - StmtHook +type ( + Hook interface { + ConnectorHook + ConnHook + TxHook + StmtHook + } + ConnHook interface { + BeforeExecContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, string, []driver.NamedValue, error) + AfterExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) + + BeforeBeginTx(ctx context.Context, opts driver.TxOptions, err error) (context.Context, driver.TxOptions, error) + AfterBeginTx(ctx context.Context, opts driver.TxOptions, dd driver.Tx, err error) (context.Context, driver.Tx, error) + + BeforeQueryContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, string, []driver.NamedValue, error) + AfterQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) + + BeforePrepareContext(ctx context.Context, query string, err error) (context.Context, string, error) + AfterPrepareContext(ctx context.Context, query string, s driver.Stmt, err error) (context.Context, driver.Stmt, error) + } + ConnectorHook interface { + BeforeConnect(ctx context.Context, err error) (context.Context, error) + AfterConnect(ctx context.Context, dc driver.Conn, err error) (context.Context, driver.Conn, error) + } + TxHook interface { + BeforeCommit(ctx context.Context, err error) (context.Context, error) + AfterCommit(ctx context.Context, err error) (context.Context, error) + BeforeRollback(ctx context.Context, err error) (context.Context, error) + AfterRollback(ctx context.Context, err error) (context.Context, error) + } + StmtHook interface { + BeforeStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, []driver.NamedValue, error) + AfterStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) + BeforeStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, []driver.NamedValue, error) + AfterStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) + } + Hooks struct { + hooks []Hook + } +) + +func NewMultiHook(hooks ...Hook) Hook { + return &Hooks{hooks: hooks} +} + +func (h *Hooks) BeforeConnect(ctx context.Context, err error) (context.Context, error) { + for _, hook := range h.hooks { + ctx, err = hook.BeforeConnect(ctx, err) + } + + return ctx, err +} + +func (h *Hooks) AfterConnect(ctx context.Context, dc driver.Conn, err error) (context.Context, driver.Conn, error) { + for _, hook := range h.hooks { + ctx, dc, err = hook.AfterConnect(ctx, dc, err) + } + + return ctx, dc, err +} + +func (h *Hooks) BeforeExecContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, string, []driver.NamedValue, error) { + for _, hook := range h.hooks { + ctx, query, args, err = hook.BeforeExecContext(ctx, query, args, err) + } + + return ctx, query, args, err +} + +func (h *Hooks) AfterExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) { + for _, hook := range h.hooks { + ctx, r, err = hook.AfterExecContext(ctx, query, args, r, err) + } + + return ctx, r, err +} + +func (h *Hooks) BeforeBeginTx(ctx context.Context, opts driver.TxOptions, err error) (context.Context, driver.TxOptions, error) { + for _, hook := range h.hooks { + ctx, opts, err = hook.BeforeBeginTx(ctx, opts, err) + } + + return ctx, opts, err +} + +func (h *Hooks) AfterBeginTx(ctx context.Context, opts driver.TxOptions, dd driver.Tx, err error) (context.Context, driver.Tx, error) { + for _, hook := range h.hooks { + ctx, dd, err = hook.AfterBeginTx(ctx, opts, dd, err) + } + + return ctx, dd, err +} + +func (h *Hooks) BeforeQueryContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, string, []driver.NamedValue, error) { + for _, hook := range h.hooks { + ctx, query, args, err = hook.BeforeQueryContext(ctx, query, args, err) + } + + return ctx, query, args, err } -type ConnHook interface { - BeforeExecContext(ctx context.Context, query string, args []driver.NamedValue) (context.Context, error) - AfterExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) - BeforeBeginTx(ctx context.Context, opts driver.TxOptions) (context.Context, error) - AfterBeginTx(ctx context.Context, opts driver.TxOptions, dd driver.Tx, err error) (context.Context, driver.Tx, error) - BeforeQueryContext(ctx context.Context, query string, args []driver.NamedValue) (context.Context, error) - AfterQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) - BeforePrepareContext(ctx context.Context, query string) (context.Context, error) - AfterPrepareContext(ctx context.Context, query string, s driver.Stmt, err error) (context.Context, driver.Stmt, error) +func (h *Hooks) AfterQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) { + for _, hook := range h.hooks { + ctx, rows, err = hook.AfterQueryContext(ctx, query, args, rows, err) + + } + + return ctx, rows, err +} + +func (h *Hooks) BeforePrepareContext(ctx context.Context, query string, err error) (context.Context, string, error) { + for _, hook := range h.hooks { + ctx, query, err = hook.BeforePrepareContext(ctx, query, err) + } + + return ctx, query, err +} + +func (h *Hooks) AfterPrepareContext(ctx context.Context, query string, s driver.Stmt, err error) (context.Context, driver.Stmt, error) { + for _, hook := range h.hooks { + ctx, s, err = hook.AfterPrepareContext(ctx, query, s, err) + } + + return ctx, s, err +} + +func (h *Hooks) BeforeCommit(ctx context.Context, err error) (context.Context, error) { + for _, hook := range h.hooks { + ctx, err = hook.BeforeCommit(ctx, err) + } + + return ctx, err +} + +func (h *Hooks) AfterCommit(ctx context.Context, err error) (context.Context, error) { + for _, hook := range h.hooks { + ctx, err = hook.AfterCommit(ctx, err) + } + + return ctx, err } -type ConnectorHook interface { - BeforeConnect(ctx context.Context) (context.Context, error) - AfterConnect(ctx context.Context, dc driver.Conn, err error) (context.Context, driver.Conn, error) +func (h *Hooks) BeforeRollback(ctx context.Context, err error) (context.Context, error) { + for _, hook := range h.hooks { + ctx, err = hook.BeforeRollback(ctx, err) + } + + return ctx, err } -type TxHook interface { - BeforeCommit(ctx context.Context) (context.Context, error) - AfterCommit(ctx context.Context, err error) (context.Context, error) - BeforeRollback(ctx context.Context) (context.Context, error) - AfterRollback(ctx context.Context, err error) (context.Context, error) +func (h *Hooks) AfterRollback(ctx context.Context, err error) (context.Context, error) { + for _, hook := range h.hooks { + ctx, err = hook.AfterRollback(ctx, err) + } + + return ctx, err } -type StmtHook interface { - BeforeStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue) (context.Context, error) - AfterStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) - BeforeStmtExecContext(ctx context.Context, query string, args []driver.NamedValue) (context.Context, error) - AfterStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) +func (h *Hooks) BeforeStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, []driver.NamedValue, error) { + for _, hook := range h.hooks { + ctx, args, err = hook.BeforeStmtQueryContext(ctx, query, args, err) + } + + return ctx, args, err +} + +func (h *Hooks) AfterStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) { + for _, hook := range h.hooks { + ctx, rows, err = hook.AfterStmtQueryContext(ctx, query, args, rows, err) + } + + return ctx, rows, err +} + +func (h *Hooks) BeforeStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, []driver.NamedValue, error) { + for _, hook := range h.hooks { + ctx, args, err = hook.BeforeStmtExecContext(ctx, query, args, err) + } + + return ctx, args, err +} + +func (h *Hooks) AfterStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) { + for _, hook := range h.hooks { + ctx, r, err = hook.AfterStmtExecContext(ctx, query, args, r, err) + } + + return ctx, r, err } diff --git a/stmt.go b/stmt.go index 8c34192..078b48f 100644 --- a/stmt.go +++ b/stmt.go @@ -18,9 +18,10 @@ type stmt struct { } func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { - ctx, err = s.BeforeStmtQueryContext(ctx, s.query, args) + query := s.query + ctx, args, err = s.BeforeStmtQueryContext(ctx, query, args, nil) defer func() { - _, rows, err = s.AfterStmtQueryContext(ctx, s.query, args, rows, err) + _, rows, err = s.AfterStmtQueryContext(ctx, query, args, rows, err) }() if err != nil { return nil, err @@ -29,22 +30,20 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows switch ss := s.Stmt.(type) { case driver.StmtQueryContext: return ss.QueryContext(ctx, args) - case interface { - Query(args []driver.Value) (driver.Rows, error) - }: + default: value, err := namedValueToValue(args) if err != nil { return nil, err } - return ss.Query(value) + return s.Query(value) } - return nil, errNoInterfaceImplementation } func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { - ctx, err = s.BeforeStmtExecContext(ctx, s.query, args) + + ctx, args, err = s.BeforeStmtExecContext(ctx, s.query, args, nil) defer func() { _, r, err = s.AfterStmtExecContext(ctx, s.query, args, r, err) }() @@ -55,16 +54,12 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r dri switch ss := s.Stmt.(type) { case driver.StmtExecContext: return ss.ExecContext(ctx, args) - case interface { - Exec(args []driver.Value) (driver.Result, error) - }: + default: value, err := namedValueToValue(args) if err != nil { return nil, err } - return ss.Exec(value) + return s.Exec(value) } - - return nil, errNoInterfaceImplementation } diff --git a/tx.go b/tx.go index c93fad2..bd1687a 100644 --- a/tx.go +++ b/tx.go @@ -12,7 +12,7 @@ type tx struct { func (t *tx) Commit() (err error) { ctx := context.Background() - ctx, err = t.BeforeCommit(ctx) + ctx, err = t.BeforeCommit(ctx, nil) defer func() { _, err = t.AfterCommit(ctx, err) }() @@ -30,7 +30,7 @@ func (t *tx) Commit() (err error) { func (t *tx) Rollback() (err error) { ctx := context.Background() - ctx, err = t.BeforeRollback(ctx) + ctx, err = t.BeforeRollback(ctx, nil) defer func() { _, err = t.AfterRollback(ctx, err) }()