diff --git a/README.md b/README.md index 0e3096a..9471530 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,8 @@ Implement the `sqlplus.Hook` interface and wrap it with `sqlplus.New(d driver.Dr # 👐ecosystem - [sqltrace](https://github.com/chenquan/sqltrace): A low-code intrusion library that provides SQL tracing capabilities, suitable for any - relational database (Sqlite3, MySQL, Oracle, SQL Server, PostgreSQL, TiDB, etc.) and ORM libraries for various + relational database (Sqlite3, MySQL, Oracle, SQL Server, PostgreSQL, TiDB, TDengine, etc.) and ORM libraries for various relational database (gorm, xorm, sqlx, etc.) - [sqlbreaker](https://github.com/chenquan/sqlbreaker): A low-code intrusion library that provides SQL breaker capabilities, suitable for any - relational database (Sqlite3, MySQL, Oracle, SQL Server, PostgreSQL, TiDB, etc.) and ORM libraries for various + relational database (Sqlite3, MySQL, Oracle, SQL Server, PostgreSQL, TiDB, TDengine, etc.) and ORM libraries for various relational database (gorm, xorm, sqlx, etc.) diff --git a/conn.go b/conn.go index 952d2eb..4c737c5 100644 --- a/conn.go +++ b/conn.go @@ -13,13 +13,17 @@ var ( _ driver.ExecerContext = (*conn)(nil) ) -type conn struct { - driver.Conn - ConnHook -} +type ( + conn struct { + driver.Conn + ConnHook + } + connKey struct{} +) func (c *conn) Close() (err error) { - ctx, err := c.BeforeClose(context.Background(), nil) + ctx := c.newConnContext(context.Background()) + ctx, err = c.BeforeClose(ctx, nil) defer func() { _, err = c.AfterClose(ctx, err) }() @@ -31,6 +35,7 @@ func (c *conn) Close() (err error) { } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { + ctx = c.newConnContext(ctx) ctx, query, args, err = c.BeforeExecContext(ctx, query, args, nil) defer func() { _, result, err = c.AfterExecContext(ctx, query, args, result, err) @@ -60,6 +65,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { + ctx = c.newConnContext(ctx) ctx, query, args, err = c.BeforeQueryContext(ctx, query, args, nil) defer func() { _, rows, err = c.AfterQueryContext(ctx, query, args, rows, err) @@ -90,6 +96,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) { + ctx = c.newConnContext(ctx) ctx, query, err = c.BeforePrepareContext(ctx, query, nil) defer func() { _, s, err = c.AfterPrepareContext(ctx, query, s, err) @@ -113,6 +120,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, } func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (dd driver.Tx, err error) { + ctx = c.newConnContext(ctx) ctx, opts, err = c.BeforeBeginTx(ctx, opts, nil) defer func() { _, dd, err = c.AfterBeginTx(ctx, opts, dd, err) @@ -147,3 +155,24 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { return dargs, nil } + +func ConnFromContext(ctx context.Context) interface { + driver.Conn + driver.ConnPrepareContext + driver.ConnBeginTx +} { + value := ctx.Value(connKey{}) + if value == nil { + return nil + } + + return value.(interface { + driver.Conn + driver.ConnPrepareContext + driver.ConnBeginTx + }) +} + +func (c *conn) newConnContext(ctx context.Context) context.Context { + return context.WithValue(ctx, connKey{}, c) +} diff --git a/conn_test.go b/conn_test.go index d34bf04..5c994ac 100644 --- a/conn_test.go +++ b/conn_test.go @@ -83,6 +83,16 @@ func (m *mockConnBeginTx) BeginTx(_ context.Context, _ driver.TxOptions) (driver return &mockTx{}, nil } +var _ driver.ConnPrepareContext = (*mockConnPrepareContext)(nil) + +type mockConnPrepareContext struct { + *mockConn +} + +func (m *mockConnPrepareContext) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return &mockStmt{}, nil +} + // ----------------- func createMockConn() (*conn, *mockHook) { diff --git a/go.mod b/go.mod index 30e4d69..027801d 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/chenquan/sqlplus go 1.16 -require github.com/stretchr/testify v1.8.0 +require github.com/stretchr/testify v1.8.1 diff --git a/go.sum b/go.sum index 5164829..2ec90f7 100644 --- a/go.sum +++ b/go.sum @@ -5,9 +5,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hook.go b/hook.go index 18cc6d4..e370e57 100644 --- a/hook.go +++ b/hook.go @@ -24,6 +24,7 @@ type ( BeforePrepareContext(ctx context.Context, query string, err error) (context.Context, string, error) AfterPrepareContext(ctx context.Context, query string, ds driver.Stmt, err error) (context.Context, driver.Stmt, error) + BeforeClose(ctx context.Context, err error) (context.Context, error) AfterClose(ctx context.Context, err error) (context.Context, error) } @@ -34,12 +35,14 @@ type ( TxHook interface { BeforeCommit(ctx context.Context, err error) (context.Context, error) AfterCommit(ctx context.Context, err error) (context.Context, error) + BeforeRollback(ctx context.Context, err error) (context.Context, error) AfterRollback(ctx context.Context, err error) (context.Context, error) } StmtHook interface { BeforeStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, []driver.NamedValue, error) AfterStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) + BeforeStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, []driver.NamedValue, error) AfterStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) } diff --git a/hook_test.go b/hook_test.go index be3788d..db8ebf0 100644 --- a/hook_test.go +++ b/hook_test.go @@ -74,10 +74,6 @@ func (m *mockHook) AfterPrepareContext(ctx context.Context, _ string, ds driver. func (m *mockHook) BeforeCommit(ctx context.Context, err error) (context.Context, error) { m.Write("BeforeCommit") - txContext := TxContextFromContext(ctx) - if txContext == nil { - panic("txContext is nil") - } prepareContext := PrepareContextFromContext(ctx) if prepareContext != nil { @@ -89,10 +85,6 @@ func (m *mockHook) BeforeCommit(ctx context.Context, err error) (context.Context func (m *mockHook) AfterCommit(ctx context.Context, err error) (context.Context, error) { m.Write("AfterCommit") - txContext := TxContextFromContext(ctx) - if txContext == nil { - panic("txContext is nil") - } prepareContext := PrepareContextFromContext(ctx) if prepareContext != nil { @@ -104,10 +96,6 @@ func (m *mockHook) AfterCommit(ctx context.Context, err error) (context.Context, func (m *mockHook) BeforeRollback(ctx context.Context, err error) (context.Context, error) { m.Write("BeforeRollback") - txContext := TxContextFromContext(ctx) - if txContext == nil { - panic("txContext is nil") - } prepareContext := PrepareContextFromContext(ctx) if prepareContext != nil { @@ -119,10 +107,6 @@ func (m *mockHook) BeforeRollback(ctx context.Context, err error) (context.Conte func (m *mockHook) AfterRollback(ctx context.Context, err error) (context.Context, error) { m.Write("AfterRollback") - txContext := TxContextFromContext(ctx) - if txContext == nil { - panic("txContext is nil") - } prepareContext := PrepareContextFromContext(ctx) if prepareContext != nil { @@ -139,11 +123,6 @@ func (m *mockHook) BeforeStmtQueryContext(ctx context.Context, _ string, args [] panic("prepareContext is nil") } - txContext := TxContextFromContext(ctx) - if txContext != nil { - panic("txContext is not nil") - } - return ctx, args, err } @@ -154,11 +133,6 @@ func (m *mockHook) AfterStmtQueryContext(ctx context.Context, _ string, _ []driv panic("prepareContext is nil") } - txContext := TxContextFromContext(ctx) - if txContext != nil { - panic("txContext is not nil") - } - return ctx, rows, err } @@ -169,11 +143,6 @@ func (m *mockHook) BeforeStmtExecContext(ctx context.Context, _ string, args []d panic("prepareContext is nil") } - txContext := TxContextFromContext(ctx) - if txContext != nil { - panic("txContext is not nil") - } - return ctx, args, err } @@ -184,11 +153,6 @@ func (m *mockHook) AfterStmtExecContext(ctx context.Context, _ string, _ []drive panic("prepareContext is nil") } - txContext := TxContextFromContext(ctx) - if txContext != nil { - panic("txContext is not nil") - } - return ctx, r, err } diff --git a/stmt.go b/stmt.go index 2c24919..9069864 100644 --- a/stmt.go +++ b/stmt.go @@ -19,6 +19,7 @@ type ( prepareContext context.Context } prepareContextKey struct{} + stmtKey struct{} ) func PrepareContextFromContext(ctx context.Context) context.Context { @@ -30,11 +31,28 @@ func PrepareContextFromContext(ctx context.Context) context.Context { return nil } +func StmtFromContext(ctx context.Context) interface { + driver.Stmt + driver.StmtExecContext + driver.StmtQueryContext +} { + value := ctx.Value(stmtKey{}) + if value != nil { + return nil + } + + return value.(interface { + driver.Stmt + driver.StmtExecContext + driver.StmtQueryContext + }) +} + // ----------------- func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { query := s.query - ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext) + ctx = s.newStmtContext(ctx) ctx, args, err = s.BeforeStmtQueryContext(ctx, query, args, nil) defer func() { _, rows, err = s.AfterStmtQueryContext(ctx, query, args, rows, err) @@ -59,7 +77,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { query := s.query - ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext) + ctx = s.newStmtContext(ctx) ctx, args, err = s.BeforeStmtExecContext(ctx, query, args, nil) defer func() { _, r, err = s.AfterStmtExecContext(ctx, query, args, r, err) @@ -80,3 +98,8 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r dri return s.Exec(value) } } + +func (s *stmt) newStmtContext(ctx context.Context) context.Context { + ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext) + return context.WithValue(ctx, stmtKey{}, s) +} diff --git a/tx.go b/tx.go index ae6c91d..fbd6818 100644 --- a/tx.go +++ b/tx.go @@ -11,22 +11,10 @@ type ( TxHook txContext context.Context } - txContextKey struct{} ) -func TxContextFromContext(ctx context.Context) context.Context { - value := ctx.Value(txContextKey{}) - if value != nil { - return value.(context.Context) - } - - return nil -} - -// ----------------- - func (t *tx) Commit() (err error) { - ctx := context.WithValue(context.Background(), txContextKey{}, t.txContext) + ctx := t.txContext ctx, err = t.BeforeCommit(ctx, nil) defer func() { _, err = t.AfterCommit(ctx, err) @@ -44,7 +32,7 @@ func (t *tx) Commit() (err error) { } func (t *tx) Rollback() (err error) { - ctx := context.WithValue(context.Background(), txContextKey{}, t.txContext) + ctx := t.txContext ctx, err = t.BeforeRollback(ctx, nil) defer func() { _, err = t.AfterRollback(ctx, err)