From e821067b4c00bac1541db9f1a3b61f775124a69f Mon Sep 17 00:00:00 2001
From: chenquan
Date: Sun, 11 Sep 2022 21:22:08 +0800
Subject: [PATCH] feat: support get prepare context

---
 conn.go |  2 +-
 stmt.go | 13 +++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/conn.go b/conn.go
index 5fc3336..96cf567 100644
--- a/conn.go
+++ b/conn.go
@@ -97,7 +97,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt,
 		return st, err
 	}
 
-	return &stmt{Stmt: st, StmtHook: c.ConnHook.(StmtHook), query: query}, nil
+	return &stmt{Stmt: st, StmtHook: c.ConnHook.(StmtHook), query: query, prepareContext: ctx}, nil
 }
 
 func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (dd driver.Tx, err error) {
diff --git a/stmt.go b/stmt.go
index 078b48f..ab77ca2 100644
--- a/stmt.go
+++ b/stmt.go
@@ -11,14 +11,27 @@ var (
 	_ driver.StmtQueryContext = (*stmt)(nil)
 )
 
+type prepareContextKey struct{}
+
+// PrepareContextFromContext returns the context that was passed to
+// PrepareContext when the statement behind ctx was prepared; stmt's
+// QueryContext attaches it before invoking the query hooks. It returns
+// nil if ctx does not carry a prepare context.
+func PrepareContextFromContext(ctx context.Context) context.Context {
+	prepareCtx, _ := ctx.Value(prepareContextKey{}).(context.Context)
+	return prepareCtx
+}
+
 type stmt struct {
 	driver.Stmt
 	query string
 	StmtHook
+	prepareContext context.Context
 }
 
 func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
 	query := s.query
+	ctx = context.WithValue(ctx, prepareContextKey{}, s.prepareContext)
 	ctx, args, err = s.BeforeStmtQueryContext(ctx, query, args, nil)
 	defer func() {
 		_, rows, err = s.AfterStmtQueryContext(ctx, query, args, rows, err)