diff --git a/.gitignore b/.gitignore index 55cd701..fa12960 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,9 @@ +.DS_Store bin/* dist/ -*.jpg \ No newline at end of file +*.jpg + +.vscode/* +*.db \ No newline at end of file diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 84e522e..aca1c8a 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,12 +1,4 @@ -# This is an example .goreleaser.yml file with some sensible defaults. -# Make sure to check the documentation at https://goreleaser.com - -# The lines below are called `modelines`. See `:help modeline` -# Feel free to remove those if you don't want/need to use them. -# yaml-language-server: $schema=https://goreleaser.com/static/schema.json -# vim: set ts=2 sw=2 tw=0 fo=cnqoj - -version: 1 +version: 2 before: hooks: @@ -16,11 +8,17 @@ before: # - go generate ./... builds: - - env: + - id: cgo-disabled + env: - CGO_ENABLED=0 goos: - - linux - windows + - linux + + - id: cgo-enabled + env: + - CGO_ENABLED=1 + goos: - darwin archives: @@ -44,3 +42,11 @@ changelog: exclude: - "^docs:" - "^test:" + +brews: + - name: chat-cli + repository: + owner: chat-cli + name: homebrew-chat-cli + description: "chat-cli is a command line tool for working with LLMs on Amazon Bedrock" + homepage: "https://github.com/chat-cli/chat-cli/" \ No newline at end of file diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..d038dfa --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,23 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version, and other tools you might need +build: + os: ubuntu-24.04 + tools: + python: "3.13" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/conf.py + +# Optionally, but recommended, +# declare the Python requirements required to build your documentation +# See 
https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt + \ No newline at end of file diff --git a/README.md b/README.md index 40511de..1774d57 100644 --- a/README.md +++ b/README.md @@ -1,40 +1,63 @@ -# chat-cli +# 💬 chat-cli 💬 A little terminal based program that lets you interact with LLMs available via [Amazon Bedrock](https://aws.amazon.com/bedrock). +![Chat Chat Chat](docs/images/index-01.png) + ## Prerequisites 1. You will need an [AWS account](https://aws.amazon.com) -2. You will need to enable the LLMs you wish to use in Amazon Bedrock via the [Model Access](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess) page in the AWS Console. The defualt LLMs for both Chat and Prompt commands are proivded by Anthropic, so it is recommended to enable these as a starting point. +2. You will need to enable the LLMs you wish to use in Amazon Bedrock via the [Model Access](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess) page in the AWS Console. The default LLMs for both Chat and Prompt commands are provided by Anthropic, so it is recommended to enable these as a starting point. 3. You will need to install the [AWS CLI](https://docs.aws.amazon.com/cli/) tool and run `aws config` from the command line to set up credentials. ## Installation At this time you can install chat-cli via pre-packaged binaries (thanks to [GoReleaser](https://goreleaser.com/)!) for your operating system/architecture combination of choice. -1. Head to https://github.com/go-micah/chat-cli/releases/latest to find the binary for your setup. +### Pre-Built Binaries + +1. Head to https://github.com/chat-cli/chat-cli/releases/latest to find the binary for your setup. 2. Download and unzip to find a pre-compiled binary file that should work on your system. 
+### Homebrew + +If you have Homebrew installed on your system, you can run the following two commands: + +```shell + brew tap chat-cli/chat-cli + brew install chat-cli +``` + Notes: -- You won't need Go installed on your system to use the pre-packaged binaries. -- These are currently unsigned binary files. For most systems, this will not be an issue, but on MacOS you will need to [follow these instructions](https://support.apple.com/guide/mac-help/open-a-mac-app-from-an-unidentified-developer-mh40616/mac). +- You won't need Go installed on your system to use the pre-packaged binaries or Homebrew. +- These are currently unsigned binary files. For most systems, this will not be an issue, but on macOS you will need to [follow these instructions](https://support.apple.com/guide/mac-help/open-a-mac-app-from-an-unidentified-developer-mh40616/mac). -## Build from source +### Build from source You will need [Go](https://go.dev) v1.22.1 installed on your system. You can type `go version` to ensure you have the correct version installed. -To build the project from source, clone this repo to your local machine and use [Make](https://www.gnu.org/software/make/manual/make.html) to build the binary. +To build the project from source, clone this repository to your local machine and use [Make](https://www.gnu.org/software/make/manual/make.html) to build the binary. - $ git clone git@github.com:go-micah/chat-cli.git - $ cd chat-cli - $ make +```shell + git clone git@github.com:chat-cli/chat-cli.git + cd chat-cli + make +``` ## Run -To run the program from within the same directory use the following command syntax. (If you downloaded a pre-packaged binary your path will be different.) +To run the program from within the same directory, use the following command syntax. + +```shell + ./bin/chat-cli +``` - $ ./bin/chat-cli +If you downloaded a pre-packaged binary or installed with Homebrew, your path will be different.
You can add your binary to your path (Homebrew does this for you) and then you can just do the following: + +```shell + chat-cli +``` ## Help @@ -48,21 +71,71 @@ There are currently three ways to interact with foundation models through this i 2. Start an interactive chat with an LLM using the `chat` command 3. Generate an image with the `image` command +## Configuration + +You can manage persistent configuration settings using the `config` command. This allows you to set default values for model-id and custom-arn that will be used automatically by the chat and prompt commands. + +### Setting Configuration Values + +```shell +# Set a default model ID +chat-cli config set model-id "anthropic.claude-3-5-sonnet-20240620-v1:0" + +# Set a custom ARN for marketplace or cross-region models +chat-cli config set custom-arn "arn:aws:bedrock:us-west-2::foundation-model/custom-model" +``` + +### Viewing Configuration + +```shell +# List all current configuration values +chat-cli config list +``` + +### Removing Configuration Values + +```shell +# Remove a specific configuration value +chat-cli config unset model-id +chat-cli config unset custom-arn +``` + +### Configuration Precedence + +The configuration system follows a clear precedence order: + +1. **Command line flags** (highest priority) - Values specified with `--model-id` or `--custom-arn` +2. **Configuration file** - Values set with `chat-cli config set` +3. **Built-in defaults** (lowest priority) - `anthropic.claude-3-5-sonnet-20240620-v1:0` for model-id + +**Important:** When both `model-id` and `custom-arn` are set, `custom-arn` takes precedence over `model-id`. This allows you to override the default model with a custom marketplace or cross-region model. 
+ +### Supported Configuration Keys + +- `model-id`: The default model identifier to use for chat and prompt commands +- `custom-arn`: A custom ARN from Bedrock marketplace or for cross-region inference + ## Prompt You can send a one liner prompt like this: - $ ./bin/chat-cli prompt "How are you today?" +```shell + chat-cli prompt "How are you today?" +``` You can also read in a file from `stdin` as part of your prompt like this: - $ cat myfile.go | ./bin/chat-cli prompt "explain this code" +```shell + cat myfile.go | chat-cli prompt "explain this code" +``` - or + or - $ ./bin/chat-cli prompt "explain this code" < myfile.go +```shell + chat-cli prompt "explain this code" < myfile.go +``` -This will add `` tags arround your document ahead of your prompt. This syntax works especially well with [Anthropic Claude](https://www.anthropic.com/product). Other models may produce different results. +This will add `<document></document>` tags around your document ahead of your prompt. This syntax works especially well with [Anthropic Claude](https://www.anthropic.com/product). Other models may produce different results. ## Chat @@ -70,84 +143,109 @@ You can start an interactive chat sessions which will remember your conversation You can start an interactive chat session like this: - $ ./bin/chat-cli chat +```shell + chat-cli chat +``` - Type `quit` to quit the interactive chat session. -## LLMs +### Saving and Restoring Chat Sessions -Currently all text based LLMs available through Amazon Bedrock are supported. The LLMs you wish to use must be enabled within Amazon Bedrock. +Starting a chat session with the `chat-cli chat` command will automatically save your chats to a local SQLite database. If you would like to restore a prior chat session, you can do so in the following way: -The default LLM is Anthropic Claude Instant v1. +Start by using the `chat list` command to list the 10 most recent chat sessions. -To switch LLMs, use the `--model-id` flag.
You can supply a valid model id from the following list of currently supported models: +```shell + chat-cli chat list +``` +This will print a list that looks something like the following: -| Provider | Model ID | Family Name | Streaming Capable | Base Model | -|-----------|-------------------------------|-------------|-------------------|------------| -| Anthropic | anthropic.claude-3-haiku-20240307-v1:0 | claude3 | yes | yes | -| Anthropic | anthropic.claude-3-sonnet-20240229-v1:0 | claude3 | yes | no | -| Anthropic | anthropic.claude-v2:1 | claude | yes | | -| Anthropic | anthropic.claude-v2 | claude | yes | | -| Anthropic | anthropic.claude-instant-v1 | claude | yes | yes | -| Cohere | cohere.command-light-text-v14 | command | yes | yes | -| Cohere | cohere.command-text-v14 | command | yes | | -| Amazon | amazon.titan-text-lite-v1 | titan | not yet | yes | -| Amazon | amazon.titan-text-express-v1 | titan | not yet | | -| AI21 Labs | ai21.j2-mid-v1 | jurassic | no | yes | -| AI21 Labs | ai21.j2-ultra-v1 | jurassic | no | | -| Meta | meta.llama2-13b-chat-v1 | llama | yes | yes | -| Meta | meta.llama2-70b-chat-v1 | llama | yes | | +``` +❯ go run main.go chat list +2024-12-17T04:29:59Z | 9be2adda-5966-45c9-8a07-f7a7d486ca36 | How do I get started with AWS? +2024-12-17T04:25:53Z | 07927821-f443-4e92-84c6-86d6fa30ebf2 | What't the best way to decide which car +2024-12-17T04:23:57Z | 6ecdece8-9547-4b8b-9f36-2b92df2f84d6 | What is the best way to decide on which +2024-12-16T04:29:09Z | 879c2dd7-ba3d-4f59-a576-a1ce556ceb4e | What do you know about optics? +2024-12-16T04:28:52Z | 3a51ea83-93df-4af4-a1b3-d1ce89d845d9 | What can you tell me about electronics? -You can supply the exact model id from the list above like so: +2024-12-16T04:25:14Z | e16d52a8-83a9-4dc6-8e74-e41610689a9e | What is a Go package for printing markdo +2024-12-16T04:24:35Z | 7c4764e1-029d-4ebe-a7d6-43ef230e5117 | Can you help me write a poem about dogs? 
+2024-12-15T05:25:14Z | 5b2c9fb0-9ed4-4616-90be-b482bc640f8c | Can you summarize what you know about Gi +2024-12-15T05:24:04Z | 042ce5bc-a693-4e8b-9db6-eb4834b5dbac | What do you know about the Go programmin +2024-12-15T04:28:47Z | 56614689-356c-4d54-bb2c-10bd5af56b93 | How are you today? +``` + +Find the `chat-id` that corresponds to the chat session you would like to load and copy it to your clipboard. Once copied, you can load that chat session like this: + +```shell + chat-cli chat --chat-id 9be2adda-5966-45c9-8a07-f7a7d486ca36 +``` + +This will print out the saved chat and leave you at a prompt where you can pick up where you left off. Future chats will continue to save with the same `chat-id` as you go. + +Please note: eventually your chat session will result in a very large prompt context. Depending on the LLM you are using, you may get an error. Consider starting a new session when your chat session gets really lengthy! - $ ./bin/chat-cli prompt "How are you today?" --model-id cohere.command-text-v14 +## List Models -Or, you can use the `Family Name` as a shortcut. Using the Family Name will select the `Base Model` as the least expensive option offered by each provider. +You can get a list of all supported models in your current region like this: - $ ./bin/chat-cli prompt "How are you today?" --model-id titan +```shell + chat-cli models list +``` + +Please note: this is the full list of all possible models. You will need to enable access for any models you'd like to use. + +## LLMs + +Currently all text-based LLMs available through Amazon Bedrock are supported. The LLMs you wish to use must be enabled within Amazon Bedrock. + +To switch LLMs, use the `--model-id` flag. + +You can supply any model id shown by `chat-cli models list` like so: + +```shell + chat-cli prompt "How are you today?" --model-id cohere.command-text-v14 +``` ## Streaming Response -By default, responses will stream to the command line as they are generated.
This can be dissabled using the `--no-stream` flag with the prompt command. Not all models offer a streaming response capability. +By default, responses will stream to the command line as they are generated. This can be disabled using the `--no-stream` flag with the prompt command. Not all models offer a streaming response capability. You can disable streaming like this: - $ ./bin/chat-cli prompt "What is event driven architecture?" --no-stream +```shell + chat-cli prompt "What is event driven architecture?" --no-stream +``` -Only streaming response capable models can be used with the `chat` command. +Only streaming response capable models can be used with the `chat` command. ## Model Config -There are several flags you can use to overide the default config settings. Not all config settings are used by each model. +There are several flags you can use to override the default config settings. Not all config settings are used by each model. --max-tokens defaults to 500 - --temperature defaults to 1 + --temperature defaults to 1.0 --topP defaults to 0.999 - --topK defaults to 250 -## Anthropic Claude 3 Vision +## Image Attachments -With the latest models from Anthropic, Claude 3 can now support uploading an image. Images can be either png or jpg and must be less than 5MB. To upload an image do the following: +Some LLMs support uploading an image. Images can be either png or jpg and must be less than 5MB. To upload an image do the following: - $ ./bin/chat-cli prompt "Explain this image" --image IMG_1234.JPG +```shell + chat-cli prompt "Explain this image" --image IMG_1234.JPG +``` -Please note this only works with models from Anthropic Claude 3. +Please note this only works with supported models. ## Image With the `image` command you can generate images with any supported Foundation Model. 
Simply follow the syntax below: - $./bin/chat-cli image "Generate an image of a cat eating cereal" - -You can specify the model with the `--model-id` flag set to model's full model id or family name. -You can also specify an output filename with the `--filename` flag. +```shell + chat-cli image "Generate an image of a cat eating cereal" +``` -## Image Models +You can specify the model with the `--model-id` flag set to model's full model id or family name. You can also specify an output filename with the `--filename` flag. -| Provider | Model ID | Family Name | Base Model | -|-----------|-------------------------------|-------------|-------------------| -| Stability AI | stability.stable-diffusion-xl-v1 | stability | yes | -| Stability AI | stability.stable-diffusion-xl-v0 | stability | | -| Amazon | amazon.titan-image-generator-v1 | titan-image |yes | \ No newline at end of file diff --git a/cmd/chat.go b/cmd/chat.go index b5c79bb..51ad657 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -1,22 +1,28 @@ /* -Copyright © 2024 NAME HERE +Copyright © 2024 Micah Walter */ package cmd import ( - "bufio" "context" - "encoding/json" "fmt" "log" "os" + "slices" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrock" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" - "github.com/go-micah/chat-cli/models" - "github.com/go-micah/go-bedrock/providers" + "github.com/chat-cli/chat-cli/db" + "github.com/chat-cli/chat-cli/factory" + "github.com/chat-cli/chat-cli/repository" + "github.com/chat-cli/chat-cli/utils" + uuid "github.com/satori/go.uuid" "github.com/spf13/cobra" + + conf "github.com/chat-cli/chat-cli/config" ) // chatCmd represents the chat command @@ -29,78 +35,190 @@ To quit the chat, just type "quit" `, Run: func(cmd *cobra.Command, args []string) { - var err error - modelId, err := cmd.PersistentFlags().GetString("model-id") + fm, err := 
conf.NewFileManager("chat-cli") if err != nil { - log.Fatalf("unable to get flag: %v", err) + log.Fatal(err) + } + + if err := fm.InitializeViper(); err != nil { + log.Fatal(err) } - // validate model is supported - m, err := models.GetModel(modelId) + // Get SQLite database path + dbPath := fm.GetDBPath() + + // Get DBDriver from config + driver := fm.GetDBDriver() + + // get options + region, err := cmd.Parent().PersistentFlags().GetString("region") if err != nil { - log.Fatalf("error: %v", err) + log.Fatalf("unable to get flag: %v", err) } - // check if model supports streaming - if !m.SupportsStreaming { - log.Fatalf("model %s does not support streaming so it can't be used with the chat function", m.ModelID) + modelIdFlag, err := cmd.PersistentFlags().GetString("model-id") + if err != nil { + log.Fatalf("unable to get flag: %v", err) } - // get options - temperature, err := cmd.PersistentFlags().GetFloat64("temperature") + customArnFlag, err := cmd.PersistentFlags().GetString("custom-arn") if err != nil { log.Fatalf("unable to get flag: %v", err) } - topP, err := cmd.PersistentFlags().GetFloat64("topP") + // Get configuration values with precedence order (flag -> config -> default) + modelId := fm.GetConfigValue("model-id", modelIdFlag, "anthropic.claude-3-5-sonnet-20240620-v1:0").(string) + customArn := fm.GetConfigValue("custom-arn", customArnFlag, "").(string) + + // Ensure custom-arn takes precedence over model-id when both are set + // If custom-arn is set (from any source), use it; otherwise use model-id + var finalModelId string + if customArn != "" { + finalModelId = customArn + } else { + finalModelId = modelId + } + + chatId, err := cmd.PersistentFlags().GetString("chat-id") if err != nil { log.Fatalf("unable to get flag: %v", err) } - topK, err := cmd.PersistentFlags().GetFloat64("topK") + temperature, err := cmd.PersistentFlags().GetFloat32("temperature") if err != nil { log.Fatalf("unable to get flag: %v", err) } - maxTokens, err := 
cmd.PersistentFlags().GetInt("max-tokens") + topP, err := cmd.PersistentFlags().GetFloat32("topP") if err != nil { log.Fatalf("unable to get flag: %v", err) } - // set up connection to AWS - region, err := cmd.Parent().PersistentFlags().GetString("region") + maxTokens, err := cmd.PersistentFlags().GetInt32("max-tokens") if err != nil { log.Fatalf("unable to get flag: %v", err) } + // set up connection to AWS cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) if err != nil { log.Fatalf("unable to load AWS config: %v", err) } + var modelIdString string + + if customArn == "" { + // Using model-id, need to validate with Bedrock + bedrockSvc := bedrock.NewFromConfig(cfg) + + // get foundation model details + model, err := bedrockSvc.GetFoundationModel(context.TODO(), &bedrock.GetFoundationModelInput{ + ModelIdentifier: &finalModelId, + }) + if err != nil { + log.Fatalf("error: %v", err) + } + + // check if this is a text model + if !slices.Contains(model.ModelDetails.OutputModalities, "TEXT") { + log.Fatalf("model %s is not a text model, so it can't be used with the chat function", *model.ModelDetails.ModelId) + } + + // check if model supports streaming + if !*model.ModelDetails.ResponseStreamingSupported { + log.Fatalf("model %s does not support streaming so it can't be used with the chat function", *model.ModelDetails.ModelId) + } + + modelIdString = *model.ModelDetails.ModelId + } else { + // Using custom-arn, skip validation and use directly + modelIdString = finalModelId + } + svc := bedrockruntime.NewFromConfig(cfg) - var bodyString []byte - var conversation string + conf := types.InferenceConfiguration{ + MaxTokens: &maxTokens, + TopP: &topP, + Temperature: &temperature, + } - // if we are using Claude 3 and the Messages API we will need this - var messages []providers.AnthropicClaudeMessage + if chatId == "" { + chatSessionId := uuid.NewV4() + chatId = chatSessionId.String() + } - accept := "*/*" - contentType := "application/json" 
+ metadata := map[string]string{ + "chat-session-id": chatId, + } + + converseStreamInput := &bedrockruntime.ConverseStreamInput{ + ModelId: aws.String(modelIdString), + InferenceConfig: &conf, + RequestMetadata: metadata, + } // initial prompt fmt.Printf("Hi there. You can ask me stuff!\n") + dbConfig := db.Config{ // not named "config": that would shadow the imported aws-sdk config package + Driver: driver, + Name: dbPath, + } + + database, err := factory.CreateDatabase(dbConfig) + if err != nil { + log.Fatalf("Failed to create database: %v", err) + } + defer database.Close() + + // Run migrations to ensure tables exist + if err := database.Migrate(); err != nil { + log.Fatalf("Failed to migrate database: %v", err) + } + + // Create repositories + chatRepo := repository.NewChatRepository(database) + + // load saved conversation only when the user asked to resume one; chatId itself is never empty here because a new UUID is generated when --chat-id is omitted + if cmd.PersistentFlags().Changed("chat-id") { + if chats, err := chatRepo.GetMessages(chatId); err != nil { + log.Printf("Failed to load messages: %v", err) + } else { + for _, chat := range chats { + if chat.Persona == "User" { + fmt.Printf("[User]: %s\n", chat.Message) + userMsg := types.Message{ + Role: types.ConversationRoleUser, + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{ + Value: chat.Message, + }, + }, + } + converseStreamInput.Messages = append(converseStreamInput.Messages, userMsg) + } else { + fmt.Printf("[Assistant]: %s\n", chat.Message) + assistantMsg := types.Message{ + Role: types.ConversationRoleAssistant, + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{ + Value: chat.Message, + }, + }, + } + converseStreamInput.Messages = append(converseStreamInput.Messages, assistantMsg) + } + } + } + } + // tty-loop for { - // stores response chunks as one string - var chunks string - // gets user input - prompt := stringPrompt(">") + prompt := utils.StringPrompt(">") // check for special words @@ -109,299 +227,73 @@ To quit the chat, just type "quit" os.Exit(0) } - // serialize body - switch m.ModelFamily { - case "claude3": - - textPrompt := providers.AnthropicClaudeContent{ - Type: "text", - Text:
prompt, - } - - message := providers.AnthropicClaudeMessage{ - Role: "user", - Content: []providers.AnthropicClaudeContent{ - textPrompt, + userMsg := types.Message{ + Role: types.ConversationRoleUser, + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{ + Value: prompt, }, - } - - messages = append(messages, message) - - body := providers.AnthropicClaudeMessagesInvokeModelInput{ - Messages: messages, - MaxTokens: maxTokens, - TopP: topP, - TopK: int(topK), - Temperature: temperature, - StopSequences: []string{}, - } + }, + } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } + converseStreamInput.Messages = append(converseStreamInput.Messages, userMsg) - case "claude": - conversation = conversation + " \\n\\nHuman: " + prompt - - body := providers.AnthropicClaudeInvokeModelInput{ - Prompt: "Human: \n\nHuman: " + conversation + "\n\nAssistant:", - MaxTokensToSample: maxTokens, - Temperature: temperature, - TopK: int(topK), - TopP: topP, - StopSequences: []string{ - "\n\nHuman:", - }, - } + output, err := svc.ConverseStream(context.Background(), converseStreamInput) - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } - case "command": - conversation = conversation + "\\n\\n" + prompt - - body := providers.CohereCommandInvokeModelInput{ - Prompt: conversation, - Temperature: temperature, - TopP: topP, - TopK: topK, - MaxTokensToSample: maxTokens, - StopSequences: []string{`""`}, - ReturnLiklihoods: "NONE", - NumGenerations: 1, - } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } - case "llama": - conversation = conversation + "\\n\\n" + prompt - - body := providers.MetaLlamaInvokeModelInput{ - Prompt: prompt, - Temperature: temperature, - TopP: topP, - MaxTokensToSample: maxTokens, - } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: 
%v", err) - } - default: - log.Fatalf("invalid model: %s", m.ModelID) - } - - // invoke with streaming response - resp, err := svc.InvokeModelWithResponseStream(context.TODO(), &bedrockruntime.InvokeModelWithResponseStreamInput{ - Accept: &accept, - ModelId: &m.ModelID, - ContentType: &contentType, - Body: bodyString, - }) if err != nil { - log.Fatalf("error from Bedrock, %v", err) + log.Fatal(err) } - // print streaming response - switch m.ModelFamily { - case "claude3": - var out providers.AnthropicClaudeMessagesInvokeModelOutput - - stream := resp.GetStream().Reader - events := stream.Events() - - for { - event := <-events - if event != nil { - if v, ok := event.(*types.ResponseStreamMemberChunk); ok { - // v has fields - err := json.Unmarshal([]byte(v.Value.Bytes), &out) - if err != nil { - log.Printf("unable to decode response:, %v", err) - continue - } - if out.Type == "content_block_delta" { - fmt.Printf("%v", out.Delta.Text) - chunks = chunks + out.Delta.Text - } - } else if v, ok := event.(*types.UnknownUnionMember); ok { - // catchall - fmt.Print(v.Value) - } - } else { - break - } - } - stream.Close() - - if stream.Err() != nil { - log.Fatalf("error from Bedrock, %v", stream.Err()) - } - fmt.Println() - - textPrompt := providers.AnthropicClaudeContent{ - Type: "text", - Text: chunks, - } - - message := providers.AnthropicClaudeMessage{ - Role: "assistant", - Content: []providers.AnthropicClaudeContent{ - textPrompt, - }, - } - - messages = append(messages, message) - - case "claude": - var out providers.AnthropicClaudeInvokeModelOutput - - stream := resp.GetStream().Reader - events := stream.Events() - - for { - event := <-events - if event != nil { - if v, ok := event.(*types.ResponseStreamMemberChunk); ok { - // v has fields - err := json.Unmarshal([]byte(v.Value.Bytes), &out) - if err != nil { - log.Printf("unable to decode response:, %v", err) - continue - } - fmt.Printf("%v", out.Completion) - chunks = chunks + out.Completion - } else if v, ok := 
event.(*types.UnknownUnionMember); ok { - // catchall - fmt.Print(v.Value) - } - } else { - break - } - } - stream.Close() - - if stream.Err() != nil { - log.Fatalf("error from Bedrock, %v", stream.Err()) - } - fmt.Println() - - conversation = conversation + " \\n\\nAssistant: " + chunks + // Use the repository without knowing the underlying database type + chat := &repository.Chat{ + ChatId: chatId, + Persona: "User", + Message: prompt, + } - case "command": + if err := chatRepo.Create(chat); err != nil { + log.Printf("Failed to create chat: %v", err) + } - var out providers.CohereCommandInvokeModelOutput + fmt.Print("[Assistant]: ") - stream := resp.GetStream().Reader - events := stream.Events() + var out string - for { - event := <-events - if event != nil { - if v, ok := event.(*types.ResponseStreamMemberChunk); ok { - // v has fields - err := json.Unmarshal([]byte(v.Value.Bytes), &out) - if err != nil { - log.Printf("unable to decode response:, %v", err) - continue - } - fmt.Printf("%v", out.Generations[0].Text) - chunks = chunks + out.Generations[0].Text + assistantMsg, err := utils.ProcessStreamingOutput(output, func(ctx context.Context, part string) error { + fmt.Print(part) + out += part + return nil + }) - } else if v, ok := event.(*types.UnknownUnionMember); ok { - // catchall - fmt.Print(v.Value) - } - } else { - break - } - } - stream.Close() + if err != nil { + log.Fatal("streaming output processing error: ", err) + } - if stream.Err() != nil { - log.Fatalf("error from Bedrock, %v", stream.Err()) - } - fmt.Println() - - conversation = conversation + "\\n\\n " + chunks - - case "llama": - var out providers.MetaLlamaInvokeModelOutput - - stream := resp.GetStream().Reader - events := stream.Events() - - for { - event := <-events - if event != nil { - if v, ok := event.(*types.ResponseStreamMemberChunk); ok { - // v has fields - err := json.Unmarshal([]byte(v.Value.Bytes), &out) - if err != nil { - log.Printf("unable to decode response:, %v", err) - 
continue - } - fmt.Printf("%v", out.Generation) - chunks = chunks + out.Generation - - } else if v, ok := event.(*types.UnknownUnionMember); ok { - // catchall - fmt.Print(v.Value) - } - } else { - break - } - } - stream.Close() + converseStreamInput.Messages = append(converseStreamInput.Messages, assistantMsg) - if stream.Err() != nil { - log.Fatalf("error from Bedrock, %v", stream.Err()) - } - fmt.Println() - conversation = conversation + "\\n\\n " + chunks + chat = &repository.Chat{ + ChatId: chatId, + Persona: "Assistant", + Message: out, + } - default: - log.Fatalf("invalid model: %s", m.ModelID) + if err := chatRepo.Create(chat); err != nil { + log.Printf("Failed to create chat: %v", err) } + fmt.Println() + } }, } func init() { rootCmd.AddCommand(chatCmd) + chatCmd.PersistentFlags().StringP("model-id", "m", "anthropic.claude-3-5-sonnet-20240620-v1:0", "set the model id") + chatCmd.PersistentFlags().String("custom-arn", "", "pass a custom arn from bedrock marketplace or cross-region inference") + chatCmd.PersistentFlags().String("chat-id", "", "pass a valid chat-id to load a previous conversation") - // Here you will define your flags and configuration settings. 
- - // Cobra supports Persistent Flags which will work for this command - // and all subcommands, e.g.: - // chatCmd.PersistentFlags().String("foo", "", "A help for foo") - chatCmd.PersistentFlags().StringP("model-id", "m", "anthropic.claude-3-haiku-20240307-v1:0", "set the model id") - - chatCmd.PersistentFlags().Float64("temperature", 1, "temperature setting") - chatCmd.PersistentFlags().Float64("topP", 0.999, "topP setting") - chatCmd.PersistentFlags().Float64("topK", 250, "topK setting") - chatCmd.PersistentFlags().Int("max-tokens", 500, "max tokens to sample") - - // Cobra supports local flags which will only run when this command - // is called directly, e.g.: - // chatCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") -} - -func stringPrompt(label string) string { - - var s string - r := bufio.NewReader(os.Stdin) - - for { - fmt.Fprint(os.Stderr, label+" ") - s, _ = r.ReadString('\n') - if s != "" { - break - } - } - - return s + chatCmd.PersistentFlags().Float32("temperature", 1.0, "temperature setting") + chatCmd.PersistentFlags().Float32("topP", 0.999, "topP setting") + chatCmd.PersistentFlags().Int32("max-tokens", 500, "max tokens") } diff --git a/cmd/chatList.go b/cmd/chatList.go new file mode 100644 index 0000000..fb8fcca --- /dev/null +++ b/cmd/chatList.go @@ -0,0 +1,89 @@ +/* +Copyright © 2024 Micah Walter +*/ +package cmd + +import ( + "fmt" + "log" + "os" + "text/tabwriter" + + conf "github.com/chat-cli/chat-cli/config" + "github.com/chat-cli/chat-cli/db" + "github.com/chat-cli/chat-cli/factory" + "github.com/chat-cli/chat-cli/repository" + "github.com/spf13/cobra" +) + +// chatListCmd represents the chatList command +var chatListCmd = &cobra.Command{ + Use: "list", + Short: "Prints a list of recent chats and IDs", + + Run: func(cmd *cobra.Command, args []string) { + + fm, err := conf.NewFileManager("chat-cli") + if err != nil { + log.Fatal(err) + } + + if err := fm.InitializeViper(); err != nil { + log.Fatal(err) + } + + // Get 
SQLite database path + dbPath := fm.GetDBPath() + + // Get the database driver from the configuration + driver := fm.GetDBDriver() + + config := db.Config{ + Driver: driver, + Name: dbPath, + } + + database, err := factory.CreateDatabase(config) + if err != nil { + log.Fatalf("Failed to create database: %v", err) + } + defer database.Close() + + // Run migrations to ensure tables exist + if err := database.Migrate(); err != nil { + log.Fatalf("Failed to migrate database: %v", err) + } + + // Create repositories + chatRepo := repository.NewChatRepository(database) + + if chats, err := chatRepo.List(); err != nil { + log.Printf("Failed to create chat: %v", err) + } else { + fmt.Println("") + + // Create a new tabwriter + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + + // Print the header + fmt.Fprintln(w, "Created Date\t Chat ID\t Title") + + fmt.Fprintln(w, "\t\t") + + for _, chat := range chats { + fmt.Fprintf(w, "%s\t %s\t %s\n", chat.Created, chat.ChatId, truncate(chat.Message, 40)) + } + } + }, +} + +func init() { + chatCmd.AddCommand(chatListCmd) +} + +func truncate(s string, length int) string { + if len(s) <= length { + return s + } + return s[:length] + "\n" +} diff --git a/cmd/config.go b/cmd/config.go new file mode 100644 index 0000000..8a76155 --- /dev/null +++ b/cmd/config.go @@ -0,0 +1,180 @@ +/* +Copyright © 2024 Micah Walter +*/ +package cmd + +import ( + "fmt" + "log" + "os" + "path/filepath" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "gopkg.in/yaml.v3" + + conf "github.com/chat-cli/chat-cli/config" +) + +// configCmd represents the config command +var configCmd = &cobra.Command{ + Use: "config", + Short: "Manage configuration settings", + Long: `Manage configuration settings for chat-cli. You can set, unset, and list configuration values.`, +} + +// configSetCmd represents the config set command +var configSetCmd = &cobra.Command{ + Use: "set ", + Short: "Set a configuration value", + Long: `Set a configuration value. 
Supported keys: custom-arn, model-id`, + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + // Initialize configuration + fm, err := conf.NewFileManager("chat-cli") + if err != nil { + log.Fatal(err) + } + + if err := fm.InitializeViper(); err != nil { + log.Fatal(err) + } + + key := args[0] + value := args[1] + + // Validate supported keys + supportedKeys := map[string]bool{ + "custom-arn": true, + "model-id": true, + } + + if !supportedKeys[key] { + fmt.Printf("Error: unsupported configuration key '%s'\n", key) + fmt.Println("Supported keys: custom-arn, model-id") + os.Exit(1) + } + + // Set the configuration value + viper.Set(key, value) + + // Write the configuration to file + if err := viper.WriteConfig(); err != nil { + fmt.Printf("Error writing config: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Configuration set: %s = %s\n", key, value) + }, +} + +// configUnsetCmd represents the config unset command +var configUnsetCmd = &cobra.Command{ + Use: "unset ", + Short: "Unset a configuration value", + Long: `Unset (remove) a configuration value. 
Supported keys: custom-arn, model-id`, + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + // Initialize configuration + fm, err := conf.NewFileManager("chat-cli") + if err != nil { + log.Fatal(err) + } + + if err := fm.InitializeViper(); err != nil { + log.Fatal(err) + } + + key := args[0] + + // Validate supported keys + supportedKeys := map[string]bool{ + "custom-arn": true, + "model-id": true, + } + + if !supportedKeys[key] { + fmt.Printf("Error: unsupported configuration key '%s'\n", key) + fmt.Println("Supported keys: custom-arn, model-id") + os.Exit(1) + } + + // Check if the key exists + if !viper.IsSet(key) { + fmt.Printf("Configuration key '%s' is not set\n", key) + return + } + + // Get config file path + configPath := filepath.Join(fm.ConfigPath, fm.ConfigFile) + + // Read current config + var configData map[string]interface{} + if configFile, err := os.ReadFile(configPath); err == nil { + yaml.Unmarshal(configFile, &configData) + } + + if configData == nil { + configData = make(map[string]interface{}) + } + + // Remove the key + delete(configData, key) + + // Write back to file + yamlData, err := yaml.Marshal(configData) + if err != nil { + fmt.Printf("Error marshaling config: %v\n", err) + os.Exit(1) + } + + if err := os.WriteFile(configPath, yamlData, 0644); err != nil { + fmt.Printf("Error writing config file: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Configuration unset: %s\n", key) + }, +} + +// configListCmd represents the config list command +var configListCmd = &cobra.Command{ + Use: "list", + Short: "List all configuration values", + Long: `List all current configuration values.`, + Run: func(cmd *cobra.Command, args []string) { + // Initialize configuration + fm, err := conf.NewFileManager("chat-cli") + if err != nil { + log.Fatal(err) + } + + if err := fm.InitializeViper(); err != nil { + log.Fatal(err) + } + + fmt.Println("Current configuration:") + + // Define the keys we care about + configKeys := 
[]string{"custom-arn", "model-id"} + + hasConfig := false + for _, key := range configKeys { + if viper.IsSet(key) { + fmt.Printf(" %s = %s\n", key, viper.GetString(key)) + hasConfig = true + } + } + + if !hasConfig { + fmt.Println(" No configuration values set") + } + }, +} + +func init() { + rootCmd.AddCommand(configCmd) + configCmd.AddCommand(configSetCmd) + configCmd.AddCommand(configUnsetCmd) + configCmd.AddCommand(configListCmd) +} \ No newline at end of file diff --git a/cmd/image.go b/cmd/image.go index 5e69de5..a229d4d 100644 --- a/cmd/image.go +++ b/cmd/image.go @@ -1,23 +1,22 @@ /* -Copyright © 2024 NAME HERE +Copyright © 2024 Micah Walter */ package cmd import ( "context" - "encoding/base64" "encoding/json" "fmt" - "io" "log" "os" + "slices" "time" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrock" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" - "github.com/go-micah/chat-cli/models" + "github.com/chat-cli/chat-cli/utils" "github.com/go-micah/go-bedrock/providers" - "github.com/mattn/go-isatty" "github.com/spf13/cobra" ) @@ -32,45 +31,42 @@ var imageCmd = &cobra.Command{ prompt := args[0] - // read a document from stdin - var document string - - if isatty.IsTerminal(os.Stdin.Fd()) || isatty.IsCygwinTerminal(os.Stdin.Fd()) { - // do nothing - } else { - stdin, err := io.ReadAll(os.Stdin) - - if err != nil { - panic(err) - } - document = string(stdin) - } - - if document != "" { - document = "\n\n" + document + "\n\n\n\n" - prompt = document + prompt - } + document, err := utils.LoadDocument() + prompt = prompt + document accept := "*/*" contentType := "application/json" var bodyString []byte - var err error + + // set up connection to AWS + region, err := cmd.Parent().PersistentFlags().GetString("region") + if err != nil { + log.Fatalf("unable to get flag: %v", err) + } + + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) + if err != nil { + log.Fatalf("unable to load AWS config: 
%v", err) + } modelId, err := cmd.PersistentFlags().GetString("model-id") if err != nil { log.Fatalf("unable to get flag: %v", err) } - // validate model is supported - m, err := models.GetModel(modelId) + bedrockSvc := bedrock.NewFromConfig(cfg) + + model, err := bedrockSvc.GetFoundationModel(context.TODO(), &bedrock.GetFoundationModelInput{ + ModelIdentifier: &modelId, + }) if err != nil { log.Fatalf("error: %v", err) } // validate model supports image generation - if m.ModelType != "image" { - log.Fatalf("model %s does not support image generation. please use a different model", m.ModelID) + if !slices.Contains(model.ModelDetails.OutputModalities, "IMAGE") { + log.Fatalf("model %s does not support image generation. please use a different model", *model.ModelDetails.ModelId) } // get options @@ -95,8 +91,8 @@ var imageCmd = &cobra.Command{ } // serialize body - switch m.ModelFamily { - case "stability": + switch *model.ModelDetails.ProviderName { + case "Stability AI": body := providers.StabilityAIStableDiffusionInvokeModelInput{ Prompt: []providers.StabilityAIStableDiffusionTextPrompt{ { @@ -112,7 +108,7 @@ var imageCmd = &cobra.Command{ if err != nil { log.Fatalf("unable to marshal body: %v", err) } - case "titan-image": + case "Amazon": body := providers.AmazonTitanImageInvokeModelInput{ TaskType: "TEXT_IMAGE", TextToImageParams: providers.AmazonTitanImageInvokeModelInputTextToImageParams{ @@ -130,25 +126,14 @@ var imageCmd = &cobra.Command{ log.Fatalf("unable to marshal body: %v", err) } default: - log.Fatalf("invalid model: %s", m.ModelID) - } - - // set up connection to AWS - region, err := cmd.Parent().PersistentFlags().GetString("region") - if err != nil { - log.Fatalf("unable to get flag: %v", err) - } - - cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) - if err != nil { - log.Fatalf("unable to load AWS config: %v", err) + log.Fatalf("invalid model: %s", *model.ModelDetails.ModelId) } svc := 
bedrockruntime.NewFromConfig(cfg) resp, err := svc.InvokeModel(context.TODO(), &bedrockruntime.InvokeModelInput{ Accept: &accept, - ModelId: &m.ModelID, + ModelId: model.ModelDetails.ModelId, ContentType: &contentType, Body: bodyString, }) @@ -157,8 +142,8 @@ var imageCmd = &cobra.Command{ } // save images to disk - switch m.ModelFamily { - case "stability": + switch *model.ModelDetails.ProviderName { + case "Stability AI": var out providers.StabilityAIStableDiffusionInvokeModelOutput err = json.Unmarshal(resp.Body, &out) @@ -166,12 +151,12 @@ var imageCmd = &cobra.Command{ log.Fatalf("unable to unmarshal response from Bedrock: %v", err) } - decoded, err := decodeImage(out.Artifacts[0].Base64) + decoded, err := utils.DecodeImage(out.Artifacts[0].Base64) if err != nil { log.Fatalf("unable to decode image: %v", err) } - outputFile := fmt.Sprintf("%s-%d.jpg", m.ModelFamily, time.Now().Unix()) + outputFile := fmt.Sprintf("%d.jpg", time.Now().Unix()) // if we have a filename set, us it instead if filename != "" { @@ -184,7 +169,7 @@ var imageCmd = &cobra.Command{ } log.Println("image written to file", outputFile) - case "titan-image": + case "Amazon": var out providers.AmazonTitanImageInvokeModelOutput err = json.Unmarshal(resp.Body, &out) @@ -192,12 +177,12 @@ var imageCmd = &cobra.Command{ log.Fatalf("unable to unmarshal response from Bedrock: %v", err) } - decoded, err := decodeImage(out.Images[0]) + decoded, err := utils.DecodeImage(out.Images[0]) if err != nil { log.Fatalf("unable to decode image: %v", err) } - outputFile := fmt.Sprintf("%s-%d.jpg", m.ModelFamily, time.Now().Unix()) + outputFile := fmt.Sprintf("%d.jpg", time.Now().Unix()) // if we have a filename set, us it instead if filename != "" { @@ -228,15 +213,7 @@ func init() { // Cobra supports local flags which will only run when this command // is called directly, e.g.: // imageCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") - imageCmd.PersistentFlags().StringP("model-id", "m", 
"stability.stable-diffusion-xl-v1", "set the model id") + imageCmd.PersistentFlags().StringP("model-id", "m", "amazon.nova-canvas-v1:0", "set the model id") imageCmd.PersistentFlags().StringP("filename", "f", "", "provide an output filename") } - -func decodeImage(base64Image string) ([]byte, error) { - decoded, err := base64.StdEncoding.DecodeString(base64Image) - if err != nil { - return nil, err - } - return decoded, nil -} diff --git a/cmd/models.go b/cmd/models.go new file mode 100644 index 0000000..752a166 --- /dev/null +++ b/cmd/models.go @@ -0,0 +1,23 @@ +/* +Copyright © 2024 Micah Walter +*/ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +// modelsCmd represents the models command +var modelsCmd = &cobra.Command{ + Use: "models", + Short: "Configure and list available models", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("models called") + }, +} + +func init() { + rootCmd.AddCommand(modelsCmd) +} diff --git a/cmd/modelsList.go b/cmd/modelsList.go new file mode 100644 index 0000000..a06efc8 --- /dev/null +++ b/cmd/modelsList.go @@ -0,0 +1,67 @@ +/* +Copyright © 2024 Micah Walter +*/ +package cmd + +import ( + "context" + "fmt" + "os" + "text/tabwriter" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrock" + "github.com/spf13/cobra" +) + +// modelsListCmd represents the list command +var modelsListCmd = &cobra.Command{ + Use: "list", + Short: "List all available models", + + Run: func(cmd *cobra.Command, args []string) { + listModels() + }, +} + +func init() { + modelsCmd.AddCommand(modelsListCmd) +} + +func listModels() { + // Load the default configuration + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-east-1")) + if err != nil { + fmt.Println("Error loading configuration:", err) + return + } + + // Create a new Bedrock client + svc := bedrock.NewFromConfig(cfg) + + // Call the ListModels API + result, err := 
svc.ListFoundationModels(context.TODO(), &bedrock.ListFoundationModelsInput{}) + if err != nil { + fmt.Println("Error listing models:", err) + return + } + + fmt.Println("") + + // Create a new tabwriter + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + + // Print the header + fmt.Fprintln(w, "Provider\t Name\t Model ID") + + fmt.Fprintln(w, "\t\t") + + // Print the models + for _, model := range result.ModelSummaries { + fmt.Fprintf(w, "%s\t %s\t %s\n", aws.ToString(model.ProviderName), aws.ToString(model.ModelName), aws.ToString(model.ModelId)) + } + + // Flush the writer + w.Flush() +} diff --git a/cmd/prompt.go b/cmd/prompt.go index 488f30c..43bc746 100644 --- a/cmd/prompt.go +++ b/cmd/prompt.go @@ -1,28 +1,22 @@ /* -Copyright © 2024 NAME HERE +Copyright © 2024 Micah Walter */ package cmd import ( - "bytes" "context" - "encoding/base64" - "encoding/json" "fmt" - "image/jpeg" - "image/png" - "io" "log" - "net/http" - "os" + "slices" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrock" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" - "github.com/go-micah/chat-cli/models" - "github.com/go-micah/go-bedrock/providers" - "github.com/mattn/go-isatty" + "github.com/chat-cli/chat-cli/utils" "github.com/spf13/cobra" + + conf "github.com/chat-cli/chat-cli/config" ) // promptCmd represents the prompt command @@ -37,492 +31,206 @@ var promptCmd = &cobra.Command{ prompt := args[0] - // read a document from stdin - var document string - - if isatty.IsTerminal(os.Stdin.Fd()) || isatty.IsCygwinTerminal(os.Stdin.Fd()) { - // do nothing - } else { - stdin, err := io.ReadAll(os.Stdin) + document, err := utils.LoadDocument() + prompt = prompt + document - if err != nil { - panic(err) - } - document = string(stdin) + // Initialize configuration + fm, err := conf.NewFileManager("chat-cli") + if err != nil { + log.Fatal(err) } - if document != "" { - document = "\n\n" + document 
+ "\n\n\n\n" - prompt = document + prompt + if err := fm.InitializeViper(); err != nil { + log.Fatal(err) } - accept := "*/*" - contentType := "application/json" - - var bodyString []byte - var err error - - modelId, err := cmd.PersistentFlags().GetString("model-id") + // set up connection to AWS + region, err := cmd.Parent().PersistentFlags().GetString("region") if err != nil { log.Fatalf("unable to get flag: %v", err) } - // validate model is supported - m, err := models.GetModel(modelId) - if err != nil { - log.Fatalf("error: %v", err) - } - - // get options - temperature, err := cmd.PersistentFlags().GetFloat64("temperature") + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) if err != nil { - log.Fatalf("unable to get flag: %v", err) + log.Fatalf("unable to load AWS config: %v", err) } - topP, err := cmd.PersistentFlags().GetFloat64("topP") + modelIdFlag, err := cmd.PersistentFlags().GetString("model-id") if err != nil { log.Fatalf("unable to get flag: %v", err) } - topK, err := cmd.PersistentFlags().GetFloat64("topK") + // get feature flag for image attachment + image, err := cmd.PersistentFlags().GetString("image") if err != nil { log.Fatalf("unable to get flag: %v", err) } - maxTokens, err := cmd.PersistentFlags().GetInt("max-tokens") + // check if --no-stream is set + noStream, err := cmd.PersistentFlags().GetBool("no-stream") if err != nil { log.Fatalf("unable to get flag: %v", err) } - image, err := cmd.PersistentFlags().GetString("image") + customArnFlag, err := cmd.PersistentFlags().GetString("custom-arn") if err != nil { log.Fatalf("unable to get flag: %v", err) } - var encodedImage string - var mimeType string - var imagePrompt providers.AnthropicClaudeContent + // Get configuration values with precedence order (flag -> config -> default) + modelId := fm.GetConfigValue("model-id", modelIdFlag, "anthropic.claude-3-5-sonnet-20240620-v1:0").(string) + customArn := fm.GetConfigValue("custom-arn", customArnFlag, "").(string) 
- if (image != "") && (m.ModelFamily != "claude3") { - log.Fatalf("model %s does not support vision. please use a different model", m.ModelID) + // Ensure custom-arn takes precedence over model-id when both are set + // If custom-arn is set (from any source), use it; otherwise use model-id + var finalModelId string + if customArn != "" { + finalModelId = customArn + } else { + finalModelId = modelId } - // serialize body - switch m.ModelFamily { - case "claude3": - textPrompt := providers.AnthropicClaudeContent{ - Type: "text", - Text: prompt, - } - - content := []providers.AnthropicClaudeContent{ - textPrompt, - } + var modelIdString string - if image != "" { - encodedImage, mimeType, err = readImage(image) - if err != nil { - log.Fatalf("unable to read image: %v", err) - } - imagePrompt = providers.AnthropicClaudeContent{ - Type: "image", - Source: &providers.AnthropicClaudeSource{ - Type: "base64", - MediaType: mimeType, - Data: encodedImage, - }, - } + bedrockSvc := bedrock.NewFromConfig(cfg) - content = append(content, imagePrompt) + if customArn == "" { + // Using model-id, need to validate with Bedrock + model, err := bedrockSvc.GetFoundationModel(context.TODO(), &bedrock.GetFoundationModelInput{ + ModelIdentifier: &finalModelId, + }) + if err != nil { + log.Fatalf("error: %v", err) } - body := providers.AnthropicClaudeMessagesInvokeModelInput{ - Messages: []providers.AnthropicClaudeMessage{ - { - Role: "user", - Content: content, - }, - }, - MaxTokens: maxTokens, - TopP: topP, - TopK: int(topK), - Temperature: temperature, - StopSequences: []string{}, + // check if this is a text model + if !slices.Contains(model.ModelDetails.OutputModalities, "TEXT") { + log.Fatalf("model %s is not a text model, so it can't be used with the chat function", *model.ModelDetails.ModelId) } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } - case "claude": - body := providers.AnthropicClaudeInvokeModelInput{ - Prompt: 
"Human: \n\nHuman: " + prompt + "\n\nAssistant:", - MaxTokensToSample: maxTokens, - Temperature: temperature, - TopK: int(topK), - TopP: topP, - StopSequences: []string{ - "\n\nHuman:", - }, + // check if model supports image/vision capabilities + if (image != "") && (!slices.Contains(model.ModelDetails.InputModalities, "IMAGE")) { + log.Fatalf("model %s does not support images as input. please use a different model", *model.ModelDetails.ModelId) } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } - case "jurassic": - body := providers.AI21LabsJurassicInvokeModelInput{ - Prompt: prompt, - Temperature: temperature, - TopP: topP, - MaxTokensToSample: maxTokens, - StopSequences: []string{`""`}, - } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } - case "command": - body := providers.CohereCommandInvokeModelInput{ - Prompt: prompt, - Temperature: temperature, - TopP: topP, - TopK: topK, - MaxTokensToSample: maxTokens, - StopSequences: []string{`""`}, - ReturnLiklihoods: "NONE", - NumGenerations: 1, - } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } - case "llama": - body := providers.MetaLlamaInvokeModelInput{ - Prompt: prompt, - Temperature: temperature, - TopP: topP, - MaxTokensToSample: maxTokens, - } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } - case "titan": - config := providers.AmazonTitanTextGenerationConfig{ - Temperature: temperature, - TopP: topP, - MaxTokensToSample: maxTokens, - StopSequences: []string{ - "User:", - }, + // check if model supports streaming and --no-stream is not set + if (!noStream) && (!*model.ModelDetails.ResponseStreamingSupported) { + log.Fatalf("model %s does not support streaming. 
please use the --no-stream flag", *model.ModelDetails.ModelId) } - body := providers.AmazonTitanTextInvokeModelInput{ - Prompt: prompt, - Config: config, - } - bodyString, err = json.Marshal(body) - if err != nil { - log.Fatalf("unable to marshal body: %v", err) - } - default: - log.Fatalf("invalid model: %s", m.ModelID) + modelIdString = *model.ModelDetails.ModelId + } else { + // Using custom-arn, skip validation and use directly + modelIdString = finalModelId } - // set up connection to AWS - region, err := cmd.Parent().PersistentFlags().GetString("region") + // get options + temperature, err := cmd.PersistentFlags().GetFloat32("temperature") if err != nil { log.Fatalf("unable to get flag: %v", err) } - cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) + topP, err := cmd.PersistentFlags().GetFloat32("topP") if err != nil { - log.Fatalf("unable to load AWS config: %v", err) + log.Fatalf("unable to get flag: %v", err) } - svc := bedrockruntime.NewFromConfig(cfg) - - // check if --no-stream is set - noStream, err := cmd.PersistentFlags().GetBool("no-stream") + maxTokens, err := cmd.PersistentFlags().GetInt32("max-tokens") if err != nil { log.Fatalf("unable to get flag: %v", err) } - // check if model supports streaming and --no-stream is not set - if (!noStream) && (!m.SupportsStreaming) { - log.Fatalf("model %s does not support streaming. 
please use the --no-stream flag", m.ModelID) + svc := bedrockruntime.NewFromConfig(cfg) + + // craft prompt + userMsg := types.Message{ + Role: types.ConversationRoleUser, + Content: []types.ContentBlock{ + &types.ContentBlockMemberText{ + Value: prompt, + }, + }, + } + + // attach image if we have one + if image != "" { + imageBytes, imageType, err := utils.ReadImage(image) + if err != nil { + log.Fatalf("unable to read image: %v", err) + } + + userMsg.Content = append(userMsg.Content, &types.ContentBlockMemberImage{ + Value: types.ImageBlock{ + Format: types.ImageFormat(imageType), + Source: &types.ImageSourceMemberBytes{ + Value: imageBytes, + }, + }, + }) + + } + + conf := types.InferenceConfiguration{ + MaxTokens: &maxTokens, + TopP: &topP, + Temperature: &temperature, } if noStream { + // set up ConverseInput with model and prompt + converseInput := &bedrockruntime.ConverseInput{ + ModelId: &modelIdString, + InferenceConfig: &conf, + } + converseInput.Messages = append(converseInput.Messages, userMsg) + // invoke and wait for full response - resp, err := svc.InvokeModel(context.TODO(), &bedrockruntime.InvokeModelInput{ - Accept: &accept, - ModelId: &m.ModelID, - ContentType: &contentType, - Body: bodyString, - }) + output, err := svc.Converse(context.TODO(), converseInput) if err != nil { log.Fatalf("error from Bedrock, %v", err) } - // print response - switch m.ModelFamily { - case "claude3": - var out providers.AnthropicClaudeMessagesInvokeModelOutput - - err = json.Unmarshal(resp.Body, &out) - if err != nil { - log.Fatalf("unable to unmarshal response from Bedrock: %v", err) - } - fmt.Println(out.Content[0].Text) - case "claude": - var out providers.AnthropicClaudeInvokeModelOutput - - err = json.Unmarshal(resp.Body, &out) - if err != nil { - log.Fatalf("unable to unmarshal response from Bedrock: %v", err) - } - fmt.Println(out.Completion) - case "jurassic": - var out providers.AI21LabsJurrasicInvokeModelOutput - - err = json.Unmarshal(resp.Body, &out) - 
if err != nil { - log.Fatalf("unable to unmarshal response from Bedrock: %v", err) - } - fmt.Println(out.Completions[0].Data.Text) - case "command": - var out providers.CohereCommandInvokeModelOutput - - err = json.Unmarshal(resp.Body, &out) - if err != nil { - log.Fatalf("unable to unmarshal response from Bedrock: %v", err) - } - fmt.Println(out.Generations[0].Text) - case "llama": - var out providers.MetaLlamaInvokeModelOutput - - err = json.Unmarshal(resp.Body, &out) - if err != nil { - log.Fatalf("unable to unmarshal response from Bedrock: %v", err) - } - fmt.Println(out.Generation) - case "titan": - var out providers.AmazonTitanTextInvokeModelOutput - - err = json.Unmarshal(resp.Body, &out) - if err != nil { - log.Fatalf("unable to unmarshal response from Bedrock: %v", err) - } - fmt.Println(out.Results[0].OutputText) - default: - log.Fatalf("invalid model: %s", m.ModelID) - } + reponse, _ := output.Output.(*types.ConverseOutputMemberMessage) + responseContentBlock := reponse.Value.Content[0] + text, _ := responseContentBlock.(*types.ContentBlockMemberText) + + fmt.Println(text.Value) + } else { + converseStreamInput := &bedrockruntime.ConverseStreamInput{ + ModelId: &modelIdString, + InferenceConfig: &conf, + } + converseStreamInput.Messages = append(converseStreamInput.Messages, userMsg) + // invoke with streaming response - resp, err := svc.InvokeModelWithResponseStream(context.TODO(), &bedrockruntime.InvokeModelWithResponseStreamInput{ - Accept: &accept, - ModelId: &m.ModelID, - ContentType: &contentType, - Body: bodyString, - }) + output, err := svc.ConverseStream(context.Background(), converseStreamInput) if err != nil { log.Fatalf("error from Bedrock, %v", err) } - // print streaming response - switch m.ModelFamily { - case "claude3": - var out providers.AnthropicClaudeMessagesInvokeModelOutput - - stream := resp.GetStream().Reader - events := stream.Events() - - for { - event := <-events - if event != nil { - if v, ok := 
event.(*types.ResponseStreamMemberChunk); ok { - // v has fields - err := json.Unmarshal([]byte(v.Value.Bytes), &out) - if err != nil { - log.Printf("unable to decode response:, %v", err) - continue - } - if out.Type == "content_block_delta" { - fmt.Printf("%v", out.Delta.Text) - } - } else if v, ok := event.(*types.UnknownUnionMember); ok { - // catchall - fmt.Print(v.Value) - } - } else { - break - } - } - stream.Close() - - if stream.Err() != nil { - log.Fatalf("error from Bedrock, %v", stream.Err()) - } - fmt.Println() - case "claude": - var out providers.AnthropicClaudeInvokeModelOutput - - stream := resp.GetStream().Reader - events := stream.Events() - - for { - event := <-events - if event != nil { - if v, ok := event.(*types.ResponseStreamMemberChunk); ok { - // v has fields - err := json.Unmarshal([]byte(v.Value.Bytes), &out) - if err != nil { - log.Printf("unable to decode response:, %v", err) - continue - } - fmt.Printf("%v", out.Completion) - } else if v, ok := event.(*types.UnknownUnionMember); ok { - // catchall - fmt.Print(v.Value) - } - } else { - break - } - } - stream.Close() - - if stream.Err() != nil { - log.Fatalf("error from Bedrock, %v", stream.Err()) - } - fmt.Println() - case "command": - var out providers.CohereCommandInvokeModelOutput - - stream := resp.GetStream().Reader - events := stream.Events() - - for { - event := <-events - if event != nil { - if v, ok := event.(*types.ResponseStreamMemberChunk); ok { - // v has fields - err := json.Unmarshal([]byte(v.Value.Bytes), &out) - if err != nil { - log.Printf("unable to decode response:, %v", err) - continue - } - fmt.Printf("%v", out.Generations[0].Text) - } else if v, ok := event.(*types.UnknownUnionMember); ok { - // catchall - fmt.Print(v.Value) - } - } else { - break - } - } - stream.Close() - - if stream.Err() != nil { - log.Fatalf("error from Bedrock, %v", stream.Err()) - } - fmt.Println() - case "llama": - var out providers.MetaLlamaInvokeModelOutput - - stream := 
resp.GetStream().Reader - events := stream.Events() - - for { - event := <-events - if event != nil { - if v, ok := event.(*types.ResponseStreamMemberChunk); ok { - // v has fields - err := json.Unmarshal([]byte(v.Value.Bytes), &out) - if err != nil { - log.Printf("unable to decode response:, %v", err) - continue - } - fmt.Printf("%v", out.Generation) - } else if v, ok := event.(*types.UnknownUnionMember); ok { - // catchall - fmt.Print(v.Value) - } - } else { - break - } - } - stream.Close() - - if stream.Err() != nil { - log.Fatalf("error from Bedrock, %v", stream.Err()) - } - fmt.Println() - default: - log.Fatalf("invalid model: %s", m.ModelID) + _, err = utils.ProcessStreamingOutput(output, func(ctx context.Context, part string) error { + fmt.Print(part) + return nil + }) + if err != nil { + log.Fatal("streaming output processing error: ", err) } + fmt.Println() } }, } -func readImage(filename string) (string, string, error) { - - data, err := os.ReadFile(filename) - if err != nil { - return "", "", err - } - - //var base64Encoding string - - // Determine the content type of the image file - mimeType := http.DetectContentType(data) - - switch mimeType { - case "image/png": - fmt.Println() - case "image/jpeg": - fmt.Println() - img, err := jpeg.Decode(bytes.NewReader(data)) - if err != nil { - return "", "", fmt.Errorf("unable to decode jpeg: %w", err) - } - - var buf bytes.Buffer - if err := png.Encode(&buf, img); err != nil { - return "", "", fmt.Errorf("unable to encode png: %w", err) - } - data = buf.Bytes() - default: - return "", "", fmt.Errorf("unsupported content typo: %s", mimeType) - } - - imgBase64Str := base64.StdEncoding.EncodeToString(data) - //r //eturn hdr.Filename, imgBase64Str, nil - - // Print the full base64 representation of the image - return imgBase64Str, mimeType, nil -} - func init() { rootCmd.AddCommand(promptCmd) - - // Here you will define your flags and configuration settings. 
- - // Cobra supports Persistent Flags which will work for this command - // and all subcommands, e.g.: - promptCmd.PersistentFlags().StringP("model-id", "m", "anthropic.claude-3-haiku-20240307-v1:0", "set the model id") + promptCmd.PersistentFlags().StringP("model-id", "m", "anthropic.claude-3-5-sonnet-20240620-v1:0", "set the model id") + promptCmd.PersistentFlags().String("custom-arn", "", "pass a custom arn from bedrock marketplace or cross-region inference") promptCmd.PersistentFlags().StringP("image", "i", "", "path to image") promptCmd.PersistentFlags().Bool("no-stream", false, "return the full response once it has completed") - promptCmd.PersistentFlags().Float64("temperature", 1, "temperature setting") - promptCmd.PersistentFlags().Float64("topP", 0.999, "topP setting") - promptCmd.PersistentFlags().Float64("topK", 250, "topK setting") - promptCmd.PersistentFlags().Int("max-tokens", 500, "max tokens to sample") - - // Cobra supports local flags which will only run when this command - // is called directly, e.g.: - // promptCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") + promptCmd.PersistentFlags().Float32("temperature", 1.0, "temperature setting") + promptCmd.PersistentFlags().Float32("topP", 0.999, "topP setting") + promptCmd.PersistentFlags().Int32("max-tokens", 500, "max tokens") } diff --git a/cmd/root.go b/cmd/root.go index 3cf3979..685e9b3 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,5 +1,5 @@ /* -Copyright © 2024 NAME HERE +Copyright © 2024 Micah Walter */ package cmd @@ -13,11 +13,16 @@ import ( var rootCmd = &cobra.Command{ Use: "chat-cli", Short: "Chat with LLMs from Amazon Bedrock!", - Long: `This is a command line tool that allows you to chat with LLMs from Amazon Bedrock!`, - - // Uncomment the following line if your bare application - // has an action associated with it: - // Run: func(cmd *cobra.Command, args []string) { }, + Long: `Chat-CLI is a command line tool that allows you to chat with LLMs from Amazon 
Bedrock! + + ██████╗██╗ ██╗ █████╗ ████████╗ ██████╗██╗ ██╗ +██╔════╝██║ ██║██╔══██╗╚══██╔══╝ ██╔════╝██║ ██║ +██║ ███████║███████║ ██║ ██║ ██║ ██║ +██║ ██╔══██║██╔══██║ ██║ ██║ ██║ ██║ +╚██████╗██║ ██║██║ ██║ ██║███████╗╚██████╗███████╗██║ + ╚═════╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═════╝╚══════╝╚═╝ + + `, } // Execute adds all child commands to the root command and sets flags appropriately. @@ -30,14 +35,5 @@ func Execute() { } func init() { - // Here you will define your flags and configuration settings. - // Cobra supports persistent flags, which, if defined here, - // will be global for your application. - - //rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.chat-cli.yaml)") rootCmd.PersistentFlags().StringP("region", "r", "us-east-1", "set the AWS region") - - // Cobra also supports local flags, which will only run - // when this action is called directly. - // rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") } diff --git a/cmd/version.go b/cmd/version.go new file mode 100644 index 0000000..54c9e62 --- /dev/null +++ b/cmd/version.go @@ -0,0 +1,29 @@ +/* +Copyright © 2024 Micah Walter +*/ +package cmd + +import ( + "fmt" + "runtime" + + "github.com/spf13/cobra" +) + +// versionCmd represents the version command +var versionCmd = &cobra.Command{ + Use: "version", + Short: "Prints the current version", + Long: `Prints the current version`, + Run: func(cmd *cobra.Command, args []string) { + // until there is a better way to do this + v := "v0.4.5" + o := runtime.GOOS + a := runtime.GOARCH + fmt.Printf("chat-cli %s, %s/%s\n", v, o, a) + }, +} + +func init() { + rootCmd.AddCommand(versionCmd) +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..20b35bd --- /dev/null +++ b/config/config.go @@ -0,0 +1,157 @@ +package config + +import ( + "os" + "path/filepath" + "runtime" + + "github.com/spf13/viper" +) + +// FileManager handles OS-specific paths for configuration and data 
storage +type FileManager struct { + AppName string + ConfigFile string + DBFile string + ConfigPath string + DataPath string + Environment string +} + +// NewFileManager creates a new instance of FileManager with OS-specific paths +func NewFileManager(appName string) (*FileManager, error) { + fm := &FileManager{ + AppName: appName, + ConfigFile: "config.yaml", + DBFile: "data.db", + Environment: os.Getenv("APP_ENV"), + } + + // Set default environment if not specified + if fm.Environment == "" { + fm.Environment = "development" + } + + // Initialize paths based on OS + if err := fm.initializePaths(); err != nil { + return nil, err + } + + return fm, nil +} + +// initializePaths sets up OS-specific paths for config and data storage +func (fm *FileManager) initializePaths() error { + var configBase string + var dataBase string + + switch runtime.GOOS { + case "windows": + appData := os.Getenv("APPDATA") + if appData == "" { + appData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Roaming") + } + configBase = appData + dataBase = appData + + case "darwin": + home := os.Getenv("HOME") + configBase = filepath.Join(home, "Library", "Application Support") + dataBase = configBase + + default: // Linux and other Unix-like systems + // Follow XDG Base Directory Specification + xdgConfig := os.Getenv("XDG_CONFIG_HOME") + if xdgConfig == "" { + xdgConfig = filepath.Join(os.Getenv("HOME"), ".config") + } + xdgData := os.Getenv("XDG_DATA_HOME") + if xdgData == "" { + xdgData = filepath.Join(os.Getenv("HOME"), ".local", "share") + } + configBase = xdgConfig + dataBase = xdgData + } + + // Set final paths + fm.ConfigPath = filepath.Join(configBase, fm.AppName) + fm.DataPath = filepath.Join(dataBase, fm.AppName) + + // Create directories if they don't exist + if err := os.MkdirAll(fm.ConfigPath, 0755); err != nil { + return err + } + if err := os.MkdirAll(fm.DataPath, 0755); err != nil { + return err + } + + return nil +} + +// InitializeViper sets up Viper with the correct 
config file path +func (fm *FileManager) InitializeViper() error { + viper.SetConfigName(fm.ConfigFile[:len(fm.ConfigFile)-len(filepath.Ext(fm.ConfigFile))]) + viper.SetConfigType("yaml") + viper.AddConfigPath(fm.ConfigPath) + + // Set some default configurations + viper.SetDefault("environment", fm.Environment) + viper.SetDefault("db_path", fm.GetDBPath()) + viper.SetDefault("db_driver", "sqlite3") + + // Create config file if it doesn't exist + if err := fm.createDefaultConfig(); err != nil { + return err + } + + return viper.ReadInConfig() +} + +// GetDBPath returns the full path to the SQLite database file +func (fm *FileManager) GetDBPath() string { + return filepath.Join(fm.DataPath, fm.DBFile) +} + +// GetDBDriver returns the database type from the config +func (fm *FileManager) GetDBDriver() string { + return viper.GetString("db_driver") +} + +// createDefaultConfig creates a default config file if it doesn't exist +func (fm *FileManager) createDefaultConfig() error { + configPath := filepath.Join(fm.ConfigPath, fm.ConfigFile) + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return viper.SafeWriteConfig() + } + return nil +} + +// GetConfigValue returns a configuration value with precedence order: +// 1. Feature flag (command line argument) +// 2. Configuration file +// 3. 
Default value +func (fm *FileManager) GetConfigValue(key string, flagValue interface{}, defaultValue interface{}) interface{} { + // Check if flag value is provided and not empty/zero value + switch v := flagValue.(type) { + case string: + if v != "" && v != defaultValue { + return v + } + case int32: + if v != 0 && v != defaultValue { + return v + } + case float32: + if v != 0.0 && v != defaultValue { + return v + } + } + + // Check configuration file + if viper.IsSet(key) { + return viper.Get(key) + } + + // Return default value + return defaultValue +} diff --git a/db/db.go b/db/db.go new file mode 100644 index 0000000..26b2ef3 --- /dev/null +++ b/db/db.go @@ -0,0 +1,21 @@ +package db + +import "database/sql" + +// Database represents a generic database connection +type Database interface { + GetDB() *sql.DB + Connect() error + Close() error + Migrate() error +} + +// Config holds common database configuration +type Config struct { + Driver string + Host string + Port int + Name string + Username string + Password string +} diff --git a/db/migrations.go b/db/migrations.go new file mode 100644 index 0000000..c0716c7 --- /dev/null +++ b/db/migrations.go @@ -0,0 +1,10 @@ +// db/migrations.go +package db + +// Migration defines what any database migration must implement +type Migration interface { + // MigrateUp creates or updates database schema + MigrateUp() error + // MigrateDown rolls back database changes (useful for testing) + MigrateDown() error +} diff --git a/db/sqlite/migrations.go b/db/sqlite/migrations.go new file mode 100644 index 0000000..668dc1f --- /dev/null +++ b/db/sqlite/migrations.go @@ -0,0 +1,59 @@ +// db/sqlite/migrations.go +package sqlite + +import ( + "database/sql" + "fmt" +) + +type SQLiteMigration struct { + db *sql.DB +} + +func NewSQLiteMigration(db *sql.DB) *SQLiteMigration { + return &SQLiteMigration{db: db} +} + +func (m *SQLiteMigration) MigrateUp() error { + + var err error + + chatsTable := ` + CREATE TABLE IF NOT EXISTS chats ( 
+ id INTEGER PRIMARY KEY AUTOINCREMENT, + chat_id TEXT NOT NULL, + persona TEXT NOT NULL, + message TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + + -- Create trigger to update the updated_at timestamp + CREATE TRIGGER IF NOT EXISTS chats_updated_at + AFTER UPDATE ON chats + BEGIN + UPDATE chats SET updated_at = CURRENT_TIMESTAMP + WHERE id = NEW.id; + END;` + + _, err = m.db.Exec(chatsTable) + if err != nil { + return fmt.Errorf("error creating users table: %v", err) + } + + return nil +} + +func (m *SQLiteMigration) MigrateDown() error { + // Drop the users table and its trigger + dropTables := ` + DROP TRIGGER IF EXISTS chats_updated_at; + DROP TABLE IF EXISTS chats;` + + _, err := m.db.Exec(dropTables) + if err != nil { + return fmt.Errorf("error dropping tables: %v", err) + } + + return nil +} diff --git a/db/sqlite/sqlite.go b/db/sqlite/sqlite.go new file mode 100644 index 0000000..8a4a45a --- /dev/null +++ b/db/sqlite/sqlite.go @@ -0,0 +1,42 @@ +// db/sqlite/sqlite.go +package sqlite + +import ( + "database/sql" + "fmt" + + "github.com/chat-cli/chat-cli/db" + + _ "github.com/mattn/go-sqlite3" +) + +type SQLiteDB struct { + db *sql.DB + config db.Config +} + +func (s *SQLiteDB) Migrate() error { + migration := NewSQLiteMigration(s.db) + return migration.MigrateUp() +} + +func NewSQLiteDB(config db.Config) *SQLiteDB { + return &SQLiteDB{config: config} +} + +func (s *SQLiteDB) Connect() error { + db, err := sql.Open("sqlite3", s.config.Name) + if err != nil { + return fmt.Errorf("sqlite connection error: %v", err) + } + s.db = db + return nil +} + +func (s *SQLiteDB) GetDB() *sql.DB { + return s.db +} + +func (s *SQLiteDB) Close() error { + return s.db.Close() +} diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..49914f2 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,3 @@ +_build/* +_static/* +venv/* \ No newline at end of file diff --git a/docs/Makefile 
b/docs/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..7e49fc4 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,29 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'Chat-CLI' +copyright = '2024, Micah Walter' +author = 'Micah Walter' +release = 'v0.4.5' +version = 'v0.4' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = ['myst_parser'] + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'furo' +html_static_path = ['_static'] diff --git a/docs/images/index-01.png b/docs/images/index-01.png new file mode 100644 index 
0000000..9917e07 Binary files /dev/null and b/docs/images/index-01.png differ diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..0aa4204 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,58 @@ +# Chat-CLI + +Chat-CLI is a little terminal based program that lets you interact with [LLM](#models)s available via [Amazon Bedrock](https://aws.amazon.com/bedrock). + +![Chat-CLI in action!](images/index-01.png) + +## Quick start + +Using [Homebrew](https://brew.sh/) do this: + +```shell + brew tap chat-cli/chat-cli + brew install chat-cli +``` + +If you have an [AWS account](#prereqs), and you have [enabled model access](#prereqs) for the LLMs you wish to use, you can do the following: + +Using the [prompt](#prompt) command, you can send one liner prompts with attachments to any text to text LLM like Anthropic's Claude Sonnet 3.5, Meta's Llama 3.2, or Amazon Nova Pro + +```shell + # set up your AWS credentials on your machine using the AWS CLI + aws configure + + # send a prompt to Anthropic Claude Sonnet 3.5 + chat-cli prompt "What is AWS?" 
+ + # read contents of a file to Chat-CLI via stdin + cat your-file.go | chat-cli prompt "explain this code" + + # attach an image for models that support vision like Anthropic Claude Sonnet 3.5 + chat-cli prompt "describe this image" --image myfile.png +``` + +With the [chat](#chat) command, you can start an interactive chat session with any text to text LLM + +```shell + # start an interactive chat session using Amazon Nova Micro + chat-cli chat +``` + +With the [image](#image) command, you can generate images with any text to image LLM like Amazon Nova Canvas or Stability AI's Stable Diffusion 3 + +```shell + # generate an image from text using Amazon Nova Canvas + chat-cli image "generate an image of a cat driving a car" +``` + +## Contents + +```{toctree} +--- +maxdepth: 3 +--- +setup +usage +models +marketplace +``` \ No newline at end of file diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/marketplace.md b/docs/marketplace.md new file mode 100644 index 0000000..b088cbb --- /dev/null +++ b/docs/marketplace.md @@ -0,0 +1,3 @@ +# Model Marketplace + +There are even more models. \ No newline at end of file diff --git a/docs/models.md b/docs/models.md new file mode 100644 index 0000000..e4d1ef9 --- /dev/null +++ b/docs/models.md @@ -0,0 +1,3 @@ +# Models + +There are many models. \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..acd729c --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,30 @@ +alabaster==1.0.0 +babel==2.16.0 +beautifulsoup4==4.12.3 +certifi==2024.12.14 +charset-normalizer==3.4.0 +docutils==0.21.2 +furo==2024.8.6 +idna==3.10 +imagesize==1.4.1 +Jinja2==3.1.4 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +mdit-py-plugins==0.4.2 +mdurl==0.1.2 +myst-parser==4.0.0 +packaging==24.2 +Pygments==2.18.0 +PyYAML==6.0.2 +requests==2.32.3 +snowballstemmer==2.2.0 +soupsieve==2.6 +Sphinx==8.1.3 +sphinx-basic-ng==1.0.0b2 +sphinxcontrib-applehelp==2.0.0 +sphinxcontrib-devhelp==2.0.0 +sphinxcontrib-htmlhelp==2.1.0 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==2.0.0 +sphinxcontrib-serializinghtml==2.0.0 +urllib3==2.2.3 diff --git a/docs/setup.md b/docs/setup.md new file mode 100644 index 0000000..58cd2f7 --- /dev/null +++ b/docs/setup.md @@ -0,0 +1,38 @@ +# Setup + +(prereqs)= +## Prerequisites + +You will need an AWS account with programmatic access keys configured on your system. These keys will need to have access to Amazon Bedrock. Additionally, you will need to enable model access for all of the LLMs you wish to use in Amazon Bedrock's Model Access page.
Here is a summary of steps to get things set up correctly. + +### Creating an AWS Account + +If you don't already have an AWS Account, you will need to create one. Simply go [here](https://portal.aws.amazon.com/billing/signup) and follow the steps to create an AWS Account from scratch. Please note, you will need to provide a valid credit card number, and there are fees associated with using LLMs with Amazon Bedrock, [outlined here](https://aws.amazon.com/bedrock/pricing/). + + + +### Configuring IAM user access to Amazon Bedrock + +### Configuring Amazon Bedrock Model Access + +### Configuring AWS credentials on your local machine + +```shell + aws configure +``` + +## Installing Chat-CLI + +### Install using Homebrew + +### Install using pre-packaged binaries + +### Install from source + +## Confirm your installation + +Once you have successfully installed Chat-CLI, you should be able to verify your installation with the following command. + +```shell + chat-cli version +``` \ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 0000000..67027d6 --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,93 @@ +# Usage + +(config)= +## Config + +Chat-CLI provides a configuration system that allows you to set persistent default values for commonly used settings. This eliminates the need to specify the same flags repeatedly when using the `chat` and `prompt` commands.
+ +### Managing Configuration + +#### Setting Values + +Use the `config set` command to store default values: + +```shell +# Set a default model ID +chat-cli config set model-id "anthropic.claude-3-5-sonnet-20240620-v1:0" + +# Set a custom ARN for marketplace or cross-region models +chat-cli config set custom-arn "arn:aws:bedrock:us-west-2::foundation-model/custom-model" +``` + +#### Viewing Configuration + +List all current configuration values: + +```shell +chat-cli config list +``` + +Example output: +``` +Current configuration: + model-id = anthropic.claude-3-5-sonnet-20240620-v1:0 + custom-arn = arn:aws:bedrock:us-west-2::foundation-model/custom-model +``` + +#### Removing Values + +Remove specific configuration values when no longer needed: + +```shell +chat-cli config unset model-id +chat-cli config unset custom-arn +``` + +### Configuration Precedence + +The configuration system uses a clear precedence hierarchy to determine which values to use: + +1. **Command line flags** (highest priority) + - Values specified with `--model-id` or `--custom-arn` flags + - Always override configuration file and defaults + +2. **Configuration file** (medium priority) + - Values set using `chat-cli config set` + - Used when no command line flag is provided + +3. **Built-in defaults** (lowest priority) + - Default model: `anthropic.claude-3-5-sonnet-20240620-v1:0` + - Used when no configuration or flags are set + +### Custom ARN Priority + +When both `model-id` and `custom-arn` are configured, `custom-arn` takes precedence. 
This design allows you to: + +- Set a default `model-id` for regular use +- Override with `custom-arn` for marketplace or cross-region models +- Use command line flags to override either setting temporarily + +### Supported Settings + +| Setting | Description | Example | +|---------|-------------|---------| +| `model-id` | Default model identifier for Bedrock foundation models | `anthropic.claude-3-5-sonnet-20240620-v1:0` | +| `custom-arn` | Custom ARN for marketplace or cross-region inference | `arn:aws:bedrock:us-west-2::foundation-model/custom-model` | + +### Configuration Storage + +Configuration values are stored in a YAML file in your system's standard configuration directory: + +- **macOS**: `~/Library/Application Support/chat-cli/config.yaml` +- **Linux**: `~/.config/chat-cli/config.yaml` +- **Windows**: `%APPDATA%\chat-cli\config.yaml` + +(prompt)= +## Prompt + +(chat)= +## Chat + +(image)= +## Image + diff --git a/factory/database.go b/factory/database.go new file mode 100644 index 0000000..ece7064 --- /dev/null +++ b/factory/database.go @@ -0,0 +1,22 @@ +package factory + +import ( + "fmt" + + "github.com/chat-cli/chat-cli/db" + "github.com/chat-cli/chat-cli/db/sqlite" +) + +// CreateDatabase is a factory function that returns the appropriate database implementation +func CreateDatabase(config db.Config) (db.Database, error) { + switch config.Driver { + case "sqlite3": + database := sqlite.NewSQLiteDB(config) + return database, database.Connect() + // case "postgres": + // database := postgres.NewPostgresDB(config) + // return database, database.Connect() + default: + return nil, fmt.Errorf("unsupported database driver: %s", config.Driver) + } +} diff --git a/go.mod b/go.mod index 1b68307..d491470 100644 --- a/go.mod +++ b/go.mod @@ -1,30 +1,51 @@ -module github.com/go-micah/chat-cli +module github.com/chat-cli/chat-cli -go 1.22.1 +go 1.23.4 require ( - github.com/aws/aws-sdk-go-v2/config v1.27.7 - github.com/aws/aws-sdk-go-v2/service/bedrockruntime 
v1.7.2 - github.com/go-micah/go-bedrock v0.1.10 + github.com/aws/aws-sdk-go-v2 v1.32.6 + github.com/aws/aws-sdk-go-v2/config v1.28.6 + github.com/aws/aws-sdk-go-v2/service/bedrock v1.25.0 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0 + github.com/go-micah/go-bedrock v0.2.0 github.com/mattn/go-isatty v0.0.20 - github.com/spf13/cobra v1.8.0 + github.com/mattn/go-sqlite3 v1.14.24 + github.com/satori/go.uuid v1.2.0 + github.com/spf13/cobra v1.8.1 + github.com/spf13/viper v1.19.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/aws/aws-sdk-go-v2 v1.25.3 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.7 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.5 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 // indirect - github.com/aws/smithy-go v1.20.1 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.47 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect + 
github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect + github.com/aws/smithy-go v1.22.1 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - golang.org/x/sys v0.6.0 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/sys v0.24.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index 6ae9488..a000c27 100644 --- a/go.sum +++ b/go.sum @@ -1,46 +1,118 @@ -github.com/aws/aws-sdk-go-v2 v1.25.3 h1:xYiLpZTQs1mzvz5PaI6uR0Wh57ippuEthxS4iK5v0n0= -github.com/aws/aws-sdk-go-v2 v1.25.3/go.mod h1:35hUlJVYd+M++iLI3ALmVwMOyRYMmRqUXpTtRGW+K9I= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 h1:gTK2uhtAPtFcdRRJilZPx8uJLL2J85xK11nKtWL0wfU= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1/go.mod h1:sxpLb+nZk7tIfCWChfd+h4QwHNUR57d8hA1cleTkjJo= -github.com/aws/aws-sdk-go-v2/config v1.27.7 h1:JSfb5nOQF01iOgxFI5OIKWwDiEXWTyTgg1Mm1mHi0A4= -github.com/aws/aws-sdk-go-v2/config v1.27.7/go.mod h1:PH0/cNpoMO+B04qET699o5W92Ca79fVtbUnvMIZro4I= -github.com/aws/aws-sdk-go-v2/credentials v1.17.7 h1:WJd+ubWKoBeRh7A5iNMnxEOs982SyVKOJD+K8HIezu4= 
-github.com/aws/aws-sdk-go-v2/credentials v1.17.7/go.mod h1:UQi7LMR0Vhvs+44w5ec8Q+VS+cd10cjwgHwiVkE0YGU= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3 h1:p+y7FvkK2dxS+FEwRIDHDe//ZX+jDhP8HHE50ppj4iI= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3/go.mod h1:/fYB+FZbDlwlAiynK9KDXlzZl3ANI9JkD0Uhz5FjNT4= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 h1:ifbIbHZyGl1alsAhPIYsHOg5MuApgqOvVeI8wIugXfs= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3/go.mod h1:oQZXg3c6SNeY6OZrDY+xHcF4VGIEoNotX2B4PrDeoJI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 h1:Qvodo9gHG9F3E8SfYOspPeBt0bjSbsevK8WhRAUHcoY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3/go.mod h1:vCKrdLXtybdf/uQd/YfVR2r5pcbNuEYKzMQpcxmeSJw= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= -github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.2 h1:J+a+y3Q8S021CPjpV4hmykYZ9rj6hz3i5CdjwS6WF40= -github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.2/go.mod h1:GodCGWC354HSLDs4LDVSa9wIO5n4/gusww/e2duy/84= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 h1:EyBZibRTVAs6ECHZOw5/wlylS9OcTzwyjeQMudmREjE= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1/go.mod h1:JKpmtYhhPs7D97NL/ltqz7yCkERFW5dOlHyVl66ZYF8= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.5 h1:K/NXvIftOlX+oGgWGIa3jDyYLDNsdVhsjHmsBH2GLAQ= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.5/go.mod h1:cl9HGLV66EnCmMNzq4sYOti+/xo8w34CsgzVtm2GgsY= -github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 h1:XOPfar83RIRPEzfihnp+U6udOveKZJvPQ76SKWrLRHc= -github.com/aws/aws-sdk-go-v2/service/sso v1.20.2/go.mod h1:Vv9Xyk1KMHXrR3vNQe8W5LMFdTjSeWk0gBZBzvf3Qa0= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2 h1:pi0Skl6mNl2w8qWZXcdOyg197Zsf4G97U7Sso9JXGZE= 
-github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2/go.mod h1:JYzLoEVeLXk+L4tn1+rrkfhkxl6mLDEVaDSvGq9og90= -github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 h1:Ppup1nVNAOWbBOrcoOxaxPeEnSFB2RnnQdguhXpmeQk= -github.com/aws/aws-sdk-go-v2/service/sts v1.28.4/go.mod h1:+K1rNPVyGxkRuv9NNiaZ4YhBFuyw2MMA9SlIJ1Zlpz8= -github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw= -github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= -github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/go-micah/go-bedrock v0.1.10 h1:Kz7euvlMspZNJVAkrP2cx5cRyQbnadJ9ibiXMOW8m04= -github.com/go-micah/go-bedrock v0.1.10/go.mod h1:2h5MwPzG4zDkBxugMQrAvwAALw6ezefrVh+h9tI9Vek= +github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= +github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= +github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= +github.com/aws/aws-sdk-go-v2/config v1.28.6/go.mod h1:GDzxJ5wyyFSCoLkS+UhGB0dArhb9mI+Co4dHtoTxbko= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod 
h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/bedrock v1.25.0 h1:n2mFFkxqCnzFCf0T9uTbVkNM6i90Fx34ggvcs1DzgOc= +github.com/aws/aws-sdk-go-v2/service/bedrock v1.25.0/go.mod h1:BKSewSMuaeUidKqXArDlT06PWK/PP3wsgLWTXKeKgQw= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0 h1:mfV5tcLXeRLbiyI4EHoHWH1sIU7JvbfXVvymUCIgZEo= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0/go.mod h1:YSSgYnasDKm5OjU3bOPkaz+2PFO6WjEQGIA6KQNsR3Q= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod 
h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-micah/go-bedrock v0.2.0 h1:eWl/g7BDOmfw8W+ULGSc/07I5H1bzbslixjRHtasDbQ= +github.com/go-micah/go-bedrock v0.2.0/go.mod h1:2h5MwPzG4zDkBxugMQrAvwAALw6ezefrVh+h9tI9Vek= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 
h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= -github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod 
h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= +github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod 
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 
v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 9ef9fb1..a246ddb 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,9 @@ /* -Copyright © 2024 NAME HERE +Copyright © 2024 Micah Walter */ package main -import "github.com/go-micah/chat-cli/cmd" +import "github.com/chat-cli/chat-cli/cmd" func main() { cmd.Execute() diff --git a/models/models.go b/models/models.go deleted file mode 100644 index 9fcbe5f..0000000 --- a/models/models.go +++ /dev/null @@ -1,150 +0,0 @@ -package models - -import ( - "fmt" - "slices" -) - -type Model struct { - ModelID string - ModelFamily string - ModelType string - BaseModel bool - SupportsStreaming bool -} - -var models = []Model{ - { - ModelID: "anthropic.claude-3-sonnet-20240229-v1:0", - ModelFamily: "claude3", - ModelType: "text", - BaseModel: false, - SupportsStreaming: true, - }, - { - ModelID: "anthropic.claude-3-haiku-20240307-v1:0", - ModelFamily: "claude3", - ModelType: "text", - BaseModel: true, - SupportsStreaming: true, - }, - { - ModelID: "anthropic.claude-v2:1", - ModelFamily: "claude", - ModelType: "text", - BaseModel: false, - SupportsStreaming: true, - }, - { - ModelID: "anthropic.claude-v2", - ModelFamily: "claude", - ModelType: "text", - BaseModel: false, - SupportsStreaming: true, - }, - { - ModelID: "anthropic.claude-instant-v1", - ModelFamily: "claude", - ModelType: "text", - BaseModel: true, - SupportsStreaming: true, - }, - { - ModelID: "ai21.j2-mid-v1", - ModelFamily: "jurassic", - ModelType: "text", - BaseModel: true, - SupportsStreaming: false, - }, - { - ModelID: "ai21.j2-ultra-v1", - ModelFamily: "jurassic", - ModelType: "text", - BaseModel: false, - SupportsStreaming: false, - }, - { - ModelID: "cohere.command-light-text-v14", - ModelFamily: "command", - ModelType: "text", - BaseModel: 
true, - SupportsStreaming: true, - }, - { - ModelID: "cohere.command-text-v14", - ModelFamily: "command", - ModelType: "text", - BaseModel: false, - SupportsStreaming: true, - }, - { - ModelID: "meta.llama2-13b-chat-v1", - ModelFamily: "llama", - ModelType: "text", - BaseModel: true, - SupportsStreaming: true, - }, - { - ModelID: "meta.llama2-70b-chat-v1", - ModelFamily: "llama", - ModelType: "text", - BaseModel: false, - SupportsStreaming: true, - }, - { - ModelID: "amazon.titan-text-lite-v1", - ModelFamily: "titan", - ModelType: "text", - BaseModel: true, - SupportsStreaming: false, - }, - { - ModelID: "amazon.titan-text-express-v1", - ModelFamily: "titan", - ModelType: "text", - BaseModel: false, - SupportsStreaming: false, - }, - { - ModelID: "amazon.titan-image-generator-v1", - ModelFamily: "titan-image", - ModelType: "image", - BaseModel: true, - SupportsStreaming: false, - }, - { - ModelID: "stability.stable-diffusion-xl-v1", - ModelFamily: "stability", - ModelType: "image", - BaseModel: true, - SupportsStreaming: false, - }, - { - ModelID: "stability.stable-diffusion-xl-v0", - ModelFamily: "stability", - ModelType: "image", - BaseModel: false, - SupportsStreaming: false, - }, -} - -func GetModel(modelId string) (Model, error) { - - var m Model - - // validate the model is supported - idx := slices.IndexFunc(models, func(m Model) bool { return m.ModelID == modelId }) - if idx == -1 { - // check if its a family shorthand - fam := slices.IndexFunc(models, func(m Model) bool { - return (m.ModelFamily == modelId) && (m.BaseModel) - }) - if fam == -1 { - return m, fmt.Errorf("model id not currently supported: %s", modelId) - } - return models[fam], nil - } - - // return associated model family and model id - return models[idx], nil -} diff --git a/repository/base.go b/repository/base.go new file mode 100644 index 0000000..19350bb --- /dev/null +++ b/repository/base.go @@ -0,0 +1,18 @@ +// repository/base.go +package repository + +import 
"github.com/chat-cli/chat-cli/db" + +// Repository defines the standard operations to be implemented by all repositories +type Repository[T any] interface { + Create(entity *T) error + GetByID(id int) (*T, error) + Update(entity *T) error + Delete(id int) error + List() ([]T, error) +} + +// BaseRepository provides common functionality for all repositories +type BaseRepository struct { + db db.Database +} diff --git a/repository/chat.go b/repository/chat.go new file mode 100644 index 0000000..3cf1bd7 --- /dev/null +++ b/repository/chat.go @@ -0,0 +1,108 @@ +// repository/chat.go +package repository + +import ( + "fmt" + + "github.com/chat-cli/chat-cli/db" +) + +// Chat is a single row of the chats table. +type Chat struct { + ID int + ChatId string + Persona string + Message string + Created string +} + +// ChatRepository implements Repository interface for Chat +type ChatRepository struct { + BaseRepository +} + +// NewChatRepository returns a ChatRepository that runs its queries against db. +func NewChatRepository(db db.Database) *ChatRepository { + return &ChatRepository{ + BaseRepository: BaseRepository{db: db}, + } +} + +// Create inserts chat into the chats table and stores the id returned by the INSERT in chat.ID. +func (r *ChatRepository) Create(chat *Chat) error { + query := ` + INSERT INTO chats (chat_id, persona, message) + VALUES ($1, $2, $3) + RETURNING id` + + err := r.db.GetDB().QueryRow(query, chat.ChatId, chat.Persona, chat.Message).Scan(&chat.ID) + if err != nil { + return fmt.Errorf("error creating chat: %w", err) + } + return nil +} + +// List returns up to 10 rows, one per chat_id, ordered by id descending. +// NOTE(review): assuming the backing database is SQLite, non-aggregated columns under GROUP BY +// come from an arbitrary row of each group; confirm which message should represent a chat. +func (r *ChatRepository) List() ([]Chat, error) { + query := ` + SELECT id, chat_id, persona, message, created_at + FROM chats + GROUP BY chat_id + ORDER BY id DESC + LIMIT 10` + + rows, err := r.db.GetDB().Query(query) + if err != nil { + return nil, fmt.Errorf("error listing chats: %w", err) + } + defer rows.Close() + + var chats []Chat + for rows.Next() { + var chat Chat + err := rows.Scan(&chat.ID, &chat.ChatId, &chat.Persona, &chat.Message, &chat.Created) + if err != nil { + return nil, fmt.Errorf("error scanning chat: %w", err) + } + chats = append(chats, chat) + } + + if err :=
rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating over chats: %w", err) + } + + return chats, nil +} + +// GetMessages returns all rows whose chat_id is chatId, in ascending id order (Created is left empty). +func (r *ChatRepository) GetMessages(chatId string) ([]Chat, error) { + query := ` + SELECT id, chat_id, persona, message + FROM chats + WHERE chat_id = $1 + ORDER BY id ASC` + + rows, err := r.db.GetDB().Query(query, chatId) + if err != nil { + return nil, fmt.Errorf("error retrieving messages: %w", err) + } + defer rows.Close() + + var chats []Chat + for rows.Next() { + var chat Chat + err := rows.Scan(&chat.ID, &chat.ChatId, &chat.Persona, &chat.Message) + if err != nil { + return nil, fmt.Errorf("error scanning chat: %w", err) + } + chats = append(chats, chat) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating over chats: %w", err) + } + + return chats, nil +} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..31b0238 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,184 @@ +package utils + +import ( + "bufio" + "context" + "encoding/base64" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/mattn/go-isatty" +) + +// StreamingOutputHandler is called by ProcessStreamingOutput with each streamed text fragment. +type StreamingOutputHandler func(ctx context.Context, part string) error + +// ProcessStreamingOutput drains a ConverseStream response, passing each text +// delta to handler, and returns the assembled message. It stops with an error +// if handler fails or the event stream reports an error. +func ProcessStreamingOutput(output *bedrockruntime.ConverseStreamOutput, handler StreamingOutputHandler) (types.Message, error) { + + var combinedResult strings.Builder + + msg := types.Message{} + + stream := output.GetStream() + defer stream.Close() // release the stream even on an early return + + for event := range stream.Events() { + switch v := event.(type) { + case *types.ConverseStreamOutputMemberMessageStart: + + msg.Role = v.Value.Role + + case *types.ConverseStreamOutputMemberContentBlockDelta: + + // Only text deltas carry output text; skip other delta kinds (e.g. tool use) + // instead of panicking on an unchecked type assertion. + textResponse, ok := v.Value.Delta.(*types.ContentBlockDeltaMemberText) + if !ok { + continue + } + if err := handler(context.Background(), textResponse.Value); err != nil { + return types.Message{}, fmt.Errorf("handling streamed output: %w", err) + } + combinedResult.WriteString(textResponse.Value) + + case
*types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + } + } + + // The Events channel closes on success and on failure alike; Err reports + // any error that ended the stream. + if err := stream.Err(); err != nil { + return types.Message{}, fmt.Errorf("reading response stream: %w", err) + } + + msg.Content = append(msg.Content, + &types.ContentBlockMemberText{ + Value: combinedResult.String(), + }, + ) + + return msg, nil +} + +// ReadImage reads filename, which must resolve inside the current working +// directory, and returns its contents together with the image format name +// ("jpeg", "png", "gif" or "webp") derived from the file extension. +func ReadImage(filename string) ([]byte, string, error) { + + // Define a base directory for allowed images + baseDir, err := os.Getwd() + if err != nil { + return nil, "", fmt.Errorf("unable to get working directory: %w", err) + } + + // Clean the filename and create the full path + cleanFilename := filepath.Clean(filename) + fullPath := filepath.Join(baseDir, cleanFilename) + + // Ensure the full path is within the base directory; only a leading ".." path element escapes it + relPath, err := filepath.Rel(baseDir, fullPath) + if err != nil || relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator)) || strings.HasPrefix(relPath, string(filepath.Separator)) { + return nil, "", fmt.Errorf("access denied: %s is outside of the allowed directory", filename) + } + + // Check if the file exists + if _, err := os.Stat(fullPath); os.IsNotExist(err) { + return nil, "", fmt.Errorf("file does not exist: %s", filename) + } + + // Read the file + data, err := os.ReadFile(fullPath) + if err != nil { + return nil, "", fmt.Errorf("unable to read file: %w", err) + } + + ext := strings.ToLower(filepath.Ext(filename)) + if ext != "" { + ext = ext[1:] // Remove the leading dot + } + + var imageType string + + switch ext { + case "jpg": + imageType = "jpeg" + case "jpeg": + imageType = "jpeg" + case "png": + imageType = "png" + case "gif": + imageType = "gif" + case "webp": + imageType = "webp" + default: + return nil, "", fmt.Errorf("unsupported file type") + + } + + return data, imageType, nil +} + +// StringPrompt prints label to stderr and reads a line from stdin, returning it +// with its trailing newline (if any). +// NOTE(review): ReadString returns "" only together with an error such as io.EOF, so +// once stdin is exhausted (e.g. piped input) this loop re-prompts without end. +func StringPrompt(label string) string { + + var s string + bufferSize := 8192 + + r := bufio.NewReaderSize(os.Stdin, bufferSize) + + for { + fmt.Fprint(os.Stderr, label+" ") + s, _ = r.ReadString('\n') + if s != "" { + break + } + } + + return s +} + +// DecodeImage decodes a standard, padded base64 string (base64.StdEncoding) into raw bytes. +func DecodeImage(base64Image string) ([]byte, error) { + decoded, err :=
base64.StdEncoding.DecodeString(base64Image) + if err != nil { + return nil, err + } + return decoded, nil +} + +// LoadDocument reads all of stdin when it is not a terminal (e.g. piped input) and returns it +// with "\n\n" prepended and "\n\n\n\n" appended; it returns "" when stdin is a terminal or empty. +func LoadDocument() (string, error) { + + // read a document from stdin + var document string + + if isatty.IsTerminal(os.Stdin.Fd()) || isatty.IsCygwinTerminal(os.Stdin.Fd()) { + // do nothing + } else { + stdin, err := io.ReadAll(os.Stdin) + + if err != nil { + return "", err + } + document = string(stdin) + } + + if document != "" { + document = "\n\n" + document + "\n\n\n\n" + } + + return document, nil +}