From ea5dfe9a75dae2faef948bb3650703f4d0fb2c23 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 11:56:01 -0400 Subject: [PATCH 1/2] Bump golang.org/x/net from 0.21.0 to 0.23.0 (#27) Bumps [golang.org/x/net](https://github.com/golang/net) from 0.21.0 to 0.23.0. - [Commits](https://github.com/golang/net/compare/v0.21.0...v0.23.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index f190eeb..a3fb0c9 100644 --- a/go.mod +++ b/go.mod @@ -20,9 +20,9 @@ require ( github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect - golang.org/x/crypto v0.19.0 // indirect - golang.org/x/net v0.21.0 // indirect - golang.org/x/sys v0.17.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/net v0.23.0 // indirect + golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 87c39f9..818bbbb 100644 --- a/go.sum +++ b/go.sum @@ -35,12 +35,12 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= From c9be9ffd67dfa3cb01a86d6f4cdae8e748cb024e Mon Sep 17 00:00:00 2001 From: ericvanlare Date: Mon, 15 Jul 2024 06:09:26 -0700 Subject: [PATCH 2/2] count all usage from all retries and failures (#29) * Count all usage from failures to unmarshal and validate json * usage counting: move provider-specific logic into provider chat files --- pkg/instructor/anthropic_chat.go | 63 +++++++++++++++++++++++++++-- pkg/instructor/chat.go | 19 +++++++-- pkg/instructor/cohere_chat.go | 61 +++++++++++++++++++++++++++- pkg/instructor/instructor.go | 7 ++++ pkg/instructor/openai_chat.go | 68 ++++++++++++++++++++++++++++++-- 5 files changed, 207 insertions(+), 11 deletions(-) diff --git a/pkg/instructor/anthropic_chat.go b/pkg/instructor/anthropic_chat.go index 683c124..3c7a57f 100644 --- a/pkg/instructor/anthropic_chat.go +++ b/pkg/instructor/anthropic_chat.go @@ -13,7 +13,10 @@ func (i *InstructorAnthropic) CreateMessages(ctx context.Context, request anthro resp, err := chatHandler(i, ctx, request, responseType) if err != nil { - return anthropic.MessagesResponse{}, err + if resp == nil { + return anthropic.MessagesResponse{}, err + } + return *nilAnthropicRespWithUsage(resp.(*anthropic.MessagesResponse)), err } response = *(resp.(*anthropic.MessagesResponse)) @@ -68,13 +71,13 @@ func (i *InstructorAnthropic) completionToolCall(ctx context.Context, request *a toolInput, err := json.Marshal(c.Input) if err != nil { - return "", nil, err + return "", nilAnthropicRespWithUsage(&resp), err } // TODO: handle more than 1 tool use return string(toolInput), &resp, nil } - return "", nil, errors.New("more than 1 tool response at a time is not implemented") + return "", nilAnthropicRespWithUsage(&resp), errors.New("more than 1 tool response at a time is not implemented") } @@ -103,3 +106,57 @@ Make sure to return an instance of the JSON, not the schema itself. return *text, &resp, nil } + +func (i *InstructorAnthropic) emptyResponseWithUsageSum(usage *UsageSum) interface{} { + return &anthropic.MessagesResponse{ + Usage: anthropic.MessagesUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + }, + } +} + +func (i *InstructorAnthropic) emptyResponseWithResponseUsage(response interface{}) interface{} { + resp, ok := response.(*anthropic.MessagesResponse) + if !ok || resp == nil { + return nil + } + + return &anthropic.MessagesResponse{ + Usage: resp.Usage, + } +} + +func (i *InstructorAnthropic) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) { + resp, ok := response.(*anthropic.MessagesResponse) + if !ok { + return response, fmt.Errorf("internal type error: expected *anthropic.MessagesResponse, got %T", response) + } + + resp.Usage.InputTokens += usage.InputTokens + resp.Usage.OutputTokens += usage.OutputTokens + + return response, nil +} + +func (i *InstructorAnthropic) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum { + resp, ok := response.(*anthropic.MessagesResponse) + if !ok { + return usage + } + + usage.InputTokens += resp.Usage.InputTokens + usage.OutputTokens += resp.Usage.OutputTokens + + return usage +} + +func nilAnthropicRespWithUsage(resp *anthropic.MessagesResponse) *anthropic.MessagesResponse { + if resp == nil { + return nil + } + + return &anthropic.MessagesResponse{ + Usage: resp.Usage, + } +} diff --git a/pkg/instructor/chat.go b/pkg/instructor/chat.go index d4a1478..4a6ec41 100644 --- a/pkg/instructor/chat.go +++ b/pkg/instructor/chat.go @@ -9,6 +9,12 @@ import ( "github.com/go-playground/validator/v10" ) +type UsageSum struct { + InputTokens int + OutputTokens int + TotalTokens int +} + func chatHandler(i Instructor, ctx context.Context, request interface{}, response any) (interface{}, error) { var err error @@ -20,12 +26,15 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons return nil, err } + // keep a running total of usage + usage := &UsageSum{} + for attempt := 0; attempt < i.MaxRetries(); attempt++ { text, resp, err := i.chat(ctx, request, schema) if err != nil { // no retry on non-marshalling/validation errors - return nil, err + return i.emptyResponseWithResponseUsage(resp), err } text = extractJSON(&text) @@ -37,6 +46,8 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons // // Currently, its just recalling with no new information // or attempt to fix the error with the last generated JSON + + i.countUsageFromResponse(resp, usage) continue } @@ -48,12 +59,14 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons if err != nil { // TODO: // add more sophisticated retry logic (send back validator error and parse error for model to fix). + + i.countUsageFromResponse(resp, usage) continue } } - return resp, nil + return i.addUsageSumToResponse(resp, usage) } - return nil, errors.New("hit max retry attempts") + return i.emptyResponseWithUsageSum(usage), errors.New("hit max retry attempts") } diff --git a/pkg/instructor/cohere_chat.go b/pkg/instructor/cohere_chat.go index 28373a2..5648143 100644 --- a/pkg/instructor/cohere_chat.go +++ b/pkg/instructor/cohere_chat.go @@ -17,7 +17,10 @@ func (i *InstructorCohere) Chat( resp, err := chatHandler(i, ctx, request, response) if err != nil { - return nil, err + if resp == nil { + return &cohere.NonStreamedChatResponse{}, err + } + return nilCohereRespWithUsage(resp.(*cohere.NonStreamedChatResponse)), err } return resp.(*cohere.NonStreamedChatResponse), nil @@ -80,6 +83,52 @@ func (i *InstructorCohere) addOrConcatJSONSystemPrompt(request *cohere.ChatReque } } +func (i *InstructorCohere) emptyResponseWithUsageSum(usage *UsageSum) interface{} { + return &cohere.NonStreamedChatResponse{ + Meta: &cohere.ApiMeta{ + Tokens: &cohere.ApiMetaTokens{ + InputTokens: toPtr(float64(usage.InputTokens)), + OutputTokens: toPtr(float64(usage.OutputTokens)), + }, + }, + } +} + +func (i *InstructorCohere) emptyResponseWithResponseUsage(response interface{}) interface{} { + resp, ok := response.(*cohere.NonStreamedChatResponse) + if !ok || resp == nil { + return nil + } + + return &cohere.NonStreamedChatResponse{ + Meta: resp.Meta, + } +} + +func (i *InstructorCohere) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) { + resp, ok := response.(*cohere.NonStreamedChatResponse) + if !ok { + return response, fmt.Errorf("internal type error: expected *cohere.NonStreamedChatResponse, got %T", response) + } + + *resp.Meta.Tokens.InputTokens += float64(usage.InputTokens) + *resp.Meta.Tokens.OutputTokens += float64(usage.OutputTokens) + + return response, nil +} + +func (i *InstructorCohere) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum { + resp, ok := response.(*cohere.NonStreamedChatResponse) + if !ok { + return usage + } + + usage.InputTokens += int(*resp.Meta.Tokens.InputTokens) + usage.OutputTokens += int(*resp.Meta.Tokens.OutputTokens) + + return usage +} + func createCohereTools(schema *Schema) *cohere.Tool { tool := &cohere.Tool{ @@ -98,3 +147,13 @@ func createCohereTools(schema *Schema) *cohere.Tool { return tool } + +func nilCohereRespWithUsage(resp *cohere.NonStreamedChatResponse) *cohere.NonStreamedChatResponse { + if resp == nil { + return nil + } + + return &cohere.NonStreamedChatResponse{ + Meta: resp.Meta, + } +} diff --git a/pkg/instructor/instructor.go b/pkg/instructor/instructor.go index 47b86f7..b1e5688 100644 --- a/pkg/instructor/instructor.go +++ b/pkg/instructor/instructor.go @@ -29,4 +29,11 @@ type Instructor interface { request interface{}, schema *Schema, ) (<-chan string, error) + + // Usage counting + + emptyResponseWithUsageSum(usage *UsageSum) interface{} + emptyResponseWithResponseUsage(response interface{}) interface{} + addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) + countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum } diff --git a/pkg/instructor/openai_chat.go b/pkg/instructor/openai_chat.go index 7505de4..b6f73a7 100644 --- a/pkg/instructor/openai_chat.go +++ b/pkg/instructor/openai_chat.go @@ -17,7 +17,10 @@ func (i *InstructorOpenAI) CreateChatCompletion( resp, err := chatHandler(i, ctx, request, responseType) if err != nil { - return openai.ChatCompletionResponse{}, err + if resp == nil { + return openai.ChatCompletionResponse{}, err + } + return *nilOpenaiRespWithUsage(resp.(*openai.ChatCompletionResponse)), err } response = *(resp.(*openai.ChatCompletionResponse)) @@ -69,7 +72,7 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha numTools := len(toolCalls) if numTools < 1 { - return "", nil, errors.New("recieved no tool calls from model, expected at least 1") + return "", nilOpenaiRespWithUsage(&resp), errors.New("received no tool calls from model, expected at least 1") } if numTools == 1 { @@ -84,14 +87,14 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha var jsonObj map[string]interface{} err = json.Unmarshal([]byte(toolCall.Function.Arguments), &jsonObj) if err != nil { - return "", nil, err + return "", nilOpenaiRespWithUsage(&resp), err } jsonArray[i] = jsonObj } resultJSON, err := json.Marshal(jsonArray) if err != nil { - return "", nil, err + return "", nilOpenaiRespWithUsage(&resp), err } return string(resultJSON), &resp, nil @@ -128,6 +131,53 @@ func (i *InstructorOpenAI) chatJSONSchema(ctx context.Context, request *openai.C return text, &resp, nil } +func (i *InstructorOpenAI) emptyResponseWithUsageSum(usage *UsageSum) interface{} { + return &openai.ChatCompletionResponse{ + Usage: openai.Usage{ + PromptTokens: usage.InputTokens, + CompletionTokens: usage.OutputTokens, + TotalTokens: usage.TotalTokens, + }, + } +} + +func (i *InstructorOpenAI) emptyResponseWithResponseUsage(response interface{}) interface{} { + resp, ok := response.(*openai.ChatCompletionResponse) + if !ok || resp == nil { + return nil + } + + return &openai.ChatCompletionResponse{ + Usage: resp.Usage, + } +} + +func (i *InstructorOpenAI) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) { + resp, ok := response.(*openai.ChatCompletionResponse) + if !ok { + return response, fmt.Errorf("internal type error: expected *openai.ChatCompletionResponse, got %T", response) + } + + resp.Usage.PromptTokens += usage.InputTokens + resp.Usage.CompletionTokens += usage.OutputTokens + resp.Usage.TotalTokens += usage.TotalTokens + + return response, nil +} + +func (i *InstructorOpenAI) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum { + resp, ok := response.(*openai.ChatCompletionResponse) + if !ok { + return usage + } + + usage.InputTokens += resp.Usage.PromptTokens + usage.OutputTokens += resp.Usage.CompletionTokens + usage.TotalTokens += resp.Usage.TotalTokens + + return usage +} + func createJSONMessage(schema *Schema) *openai.ChatCompletionMessage { message := fmt.Sprintf(` Please respond with JSON in the following JSON schema: @@ -144,3 +194,13 @@ Make sure to return an instance of the JSON, not the schema itself return msg } + +func nilOpenaiRespWithUsage(resp *openai.ChatCompletionResponse) *openai.ChatCompletionResponse { + if resp == nil { + return nil + } + + return &openai.ChatCompletionResponse{ + Usage: resp.Usage, + } +}