diff --git a/go.mod b/go.mod index 6fed35e..e35c704 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/snyk/code-client-go -go 1.23.0 - -toolchain go1.23.6 +go 1.23.6 require ( github.com/go-git/go-git/v5 v5.14.0 diff --git a/llm/api_client.go b/llm/api_client.go index f39f1e5..a600b1a 100644 --- a/llm/api_client.go +++ b/llm/api_client.go @@ -4,20 +4,19 @@ import ( "bytes" "context" "encoding/json" + "github.com/snyk/code-client-go/observability" "io" "net/http" "net/url" - - "github.com/snyk/code-client-go/observability" ) -const ( +var ( completeStatus = "COMPLETE" failedToObtainRequestIdString = "Failed to obtain request id. " defaultEndpointURL = "http://localhost:10000/explain" ) -func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options ExplainOptions) (explainResponse, error) { +func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options ExplainOptions) (Explanations, error) { span := d.instrumentor.StartSpan(ctx, "code.RunExplain") defer span.Finish() @@ -25,7 +24,7 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain logger := d.logger.With().Str("method", "code.RunExplain").Str("requestId", requestId).Logger() if err != nil { logger.Err(err).Msg(failedToObtainRequestIdString + err.Error()) - return explainResponse{}, err + return Explanations{}, err } logger.Debug().Msg("API: Retrieving explain for bundle") @@ -34,7 +33,7 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain requestBody, err := d.explainRequestBody(&options) if err != nil { logger.Err(err).Str("requestBody", string(requestBody)).Msg("error creating request body") - return explainResponse{}, err + return Explanations{}, err } logger.Debug().Str("payload body: %s\n", string(requestBody)).Msg("Marshaled payload") @@ -43,13 +42,13 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain u, err = url.Parse(defaultEndpointURL) if err != nil { logger.Err(err).Send() - return explainResponse{}, err + return Explanations{}, err } } req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewBuffer(requestBody)) if err != nil { logger.Err(err).Str("requestBody", string(requestBody)).Msg("error creating request") - return explainResponse{}, err + return Explanations{}, err } d.addDefaultHeaders(req, requestId) @@ -57,7 +56,7 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain resp, err := d.httpClientFunc().Do(req) //nolint:bodyclose // this seems to be a false positive if err != nil { logger.Err(err).Str("requestBody", string(requestBody)).Msg("error getting response") - return explainResponse{}, err + return Explanations{}, err } defer func(Body io.ReadCloser) { bodyCloseErr := Body.Close() @@ -70,25 +69,28 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain responseBody, err := io.ReadAll(resp.Body) if err != nil { logger.Err(err).Str("requestBody", string(requestBody)).Msg("error reading all response") - return explainResponse{}, err + return Explanations{}, err } logger.Debug().Str("response body: %s\n", string(responseBody)).Msg("Got the response") - var response explainResponse + var explains Explanations response.Status = completeStatus err = json.Unmarshal(responseBody, &response) if err != nil { logger.Err(err).Str("responseBody", string(responseBody)).Msg("error unmarshalling") - return explainResponse{}, err + return Explanations{}, err } - return response, nil + + explains = response.Explanation + + return explains, nil } func (d *DeepCodeLLMBindingImpl) explainRequestBody(options *ExplainOptions) ([]byte, error) { logger := d.logger.With().Str("method", "code.explainRequestBody").Logger() var request explainRequest - if options.Diff == "" { + if len(options.Diffs) == 0 { request.VulnExplanation = &explainVulnerabilityRequest{ RuleId: options.RuleKey, Derivation: options.Derivation, @@ -99,7 +101,7 @@ func (d *DeepCodeLLMBindingImpl) explainRequestBody(options *ExplainOptions) ([] } else { request.FixExplanation = &explainFixRequest{ RuleId: options.RuleKey, - Diff: options.Diff, + Diffs: options.Diffs, ExplanationLength: SHORT, } logger.Debug().Msg("payload for FixExplanation") diff --git a/llm/api_client_test.go b/llm/api_client_test.go index 7d628f4..435e8a3 100644 --- a/llm/api_client_test.go +++ b/llm/api_client_test.go @@ -22,7 +22,7 @@ func TestDeepcodeLLMBinding_runExplain(t *testing.T) { options ExplainOptions serverResponse string serverStatusCode int - expectedResponse explainResponse + expectedResponse Explanations expectedError string expectedLogMessage string }{ @@ -33,25 +33,19 @@ func TestDeepcodeLLMBinding_runExplain(t *testing.T) { Derivation: "Derivation", RuleMessage: "rule-message", }, - serverResponse: `{"explanation": "This is a vulnerability explanation"}`, + serverResponse: "{\n \"explanation\": \n {\n \"explanation1\": \"This is the first explanation\",\n \"explanation2\": \"this is the second explanation\"\n }\n}", serverStatusCode: http.StatusOK, - expectedResponse: explainResponse{ - Status: completeStatus, - Explanation: "This is a vulnerability explanation", - }, + expectedResponse: map[string]string{"explanation1": "This is the first explanation", "explanation2": "this is the second explanation"}, }, { name: "successful fix explanation", options: ExplainOptions{ RuleKey: "rule-key", - Diff: "Diff", + Diffs: []string{"Diffs"}, }, - serverResponse: `{"explanation": "This is a fix explanation"}`, + serverResponse: "{\n \"explanation\": \n {\n \"explanation1\": \"This is the first explanation\",\n \"explanation2\": \"this is the second explanation\"\n }\n}", serverStatusCode: http.StatusOK, - expectedResponse: explainResponse{ - Status: completeStatus, - Explanation: "This is a fix explanation", - }, + expectedResponse: map[string]string{"explanation1": "This is the first explanation", "explanation2": "this is the second explanation"}, }, { name: "error creating request body", @@ -146,7 +140,7 @@ func TestDeepcodeLLMBinding_explainRequestBody(t *testing.T) { t.Run("FixExplanation", func(t *testing.T) { options := &ExplainOptions{ RuleKey: "test-rule-key", - Diff: "test-Diff", + Diffs: []string{"test-Diffs"}, } requestBody, err := d.explainRequestBody(options) require.NoError(t, err) @@ -158,7 +152,7 @@ func TestDeepcodeLLMBinding_explainRequestBody(t *testing.T) { assert.Nil(t, request.VulnExplanation) assert.NotNil(t, request.FixExplanation) assert.Equal(t, "test-rule-key", request.FixExplanation.RuleId) - assert.Equal(t, "test-Diff", request.FixExplanation.Diff) + assert.Equal(t, []string{"test-Diffs"}, request.FixExplanation.Diffs) assert.Equal(t, SHORT, request.FixExplanation.ExplanationLength) }) } diff --git a/llm/binding.go b/llm/binding.go index b79e1f5..dd5f190 100644 --- a/llm/binding.go +++ b/llm/binding.go @@ -44,10 +44,11 @@ type SnykLLMBindings interface { // output - a channel that can be used to stream the results Explain(ctx context.Context, input AIRequest, format OutputFormat, output chan<- string) error } +type ExplainResult map[string]string type DeepCodeLLMBinding interface { SnykLLMBindings - ExplainWithOptions(ctx context.Context, options ExplainOptions) (string, error) + ExplainWithOptions(ctx context.Context, options ExplainOptions) (ExplainResult, error) } // DeepCodeLLMBindingImpl is an LLM binding for the Snyk Code LLM. @@ -60,15 +61,23 @@ type DeepCodeLLMBindingImpl struct { endpoint *url.URL } -func (d *DeepCodeLLMBindingImpl) ExplainWithOptions(ctx context.Context, options ExplainOptions) (string, error) { +func (d *DeepCodeLLMBindingImpl) ExplainWithOptions(ctx context.Context, options ExplainOptions) (ExplainResult, error) { s := d.instrumentor.StartSpan(ctx, "code.ExplainWithOptions") defer d.instrumentor.Finish(s) response, err := d.runExplain(s.Context(), options) + explainResult := ExplainResult{} if err != nil { - return "", err + return explainResult, err + } + index := 0 + for _, explanation := range response { + if index < len(options.Diffs) { + explainResult[options.Diffs[index]] = explanation + } + index++ } - return response.Explanation, nil + return explainResult, nil } func (d *DeepCodeLLMBindingImpl) PublishIssues(_ context.Context, _ []map[string]string) error { @@ -85,7 +94,11 @@ func (d *DeepCodeLLMBindingImpl) Explain(ctx context.Context, input AIRequest, _ if err != nil { return err } - output <- response + jsonBytes, err := json.Marshal(response) + if err != nil { + return err + } + output <- string(jsonBytes) return nil } diff --git a/llm/binding_test.go b/llm/binding_test.go index e97ea15..0bcebf0 100644 --- a/llm/binding_test.go +++ b/llm/binding_test.go @@ -30,7 +30,7 @@ func TestExplainWithOptions(t *testing.T) { explainResponseJSON := explainResponse{ Status: completeStatus, - Explanation: "mock explanation", + Explanation: map[string]string{"explanation1": "This is the first explanation"}, } expectedResponseBody, err := json.Marshal(explainResponseJSON) @@ -41,9 +41,14 @@ func TestExplainWithOptions(t *testing.T) { Body: io.NopCloser(strings.NewReader(string(expectedResponseBody))), } mockHTTPClient.EXPECT().Do(gomock.Any()).Return(&mockResponse, nil) - explanation, err := d.ExplainWithOptions(context.Background(), ExplainOptions{}) + testDiff := "test diff" + explanation, err := d.ExplainWithOptions(context.Background(), ExplainOptions{Diffs: []string{testDiff}}) assert.NoError(t, err) - assert.Equal(t, explainResponseJSON.Explanation, explanation) + var exptectedExplanationsResponse explainResponse + err = json.Unmarshal(expectedResponseBody, &exptectedExplanationsResponse) + assert.NoError(t, err) + expectedResExplanations := exptectedExplanationsResponse.Explanation + assert.Equal(t, expectedResExplanations["explanation1"], explanation[testDiff]) }) t.Run("runExplain error", func(t *testing.T) { diff --git a/llm/types.go b/llm/types.go index daafc69..70a3bf5 100644 --- a/llm/types.go +++ b/llm/types.go @@ -9,15 +9,15 @@ const ( ) type explainVulnerabilityRequest struct { - RuleId string `json:"rule_key"` + RuleId string `json:"rule_id"` RuleMessage string `json:"rule_message"` Derivation string `json:"Derivation"` ExplanationLength explanationLength `json:"explanation_length"` } type explainFixRequest struct { - RuleId string `json:"rule_key"` - Diff string `json:"diff"` + RuleId string `json:"rule_id"` + Diffs []string `json:"diffs"` ExplanationLength explanationLength `json:"explanation_length"` } @@ -27,10 +27,10 @@ type explainRequest struct { } type explainResponse struct { - Status string `json:"status"` - Explanation string `json:"explanation"` + Status string `json:"status"` + Explanation Explanations `json:"explanation"` } - +type Explanations map[string]string type ExplainOptions struct { // Derivation = Code Flow // const derivationLineNumbers: Set = new Set(); @@ -62,5 +62,5 @@ type ExplainOptions struct { RuleMessage string `json:"rule_message"` // fix difference - Diff string `json:"diff"` + Diffs []string `json:"diffs"` }