From fcbca80aa63b71981ecf8e344ed5c409206901bc Mon Sep 17 00:00:00 2001 From: reugn Date: Wed, 23 Oct 2024 11:54:39 +0300 Subject: [PATCH] feat!: support multiple generative models --- README.md | 14 ++++++----- cmd/gemini/main.go | 5 ++-- gemini/chat_session.go | 44 ++++++++++++++++++++++++++++++----- go.mod | 1 + go.sum | 6 +++++ internal/cli/chat.go | 1 + internal/cli/command.go | 50 ++++++++++++++++++++++++++++++++-------- internal/cli/select.go | 51 +++++++++++++++++++++++++++++++++++++++++ 8 files changed, 148 insertions(+), 24 deletions(-) create mode 100644 internal/cli/select.go diff --git a/README.md b/README.md index bf6cfe5..47a1a76 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,12 @@ If you don't already have one, create a key in [Google AI Studio](https://makers The system chat message must begin with an exclamation mark and is used for internal operations. A short list of supported system commands: -| Command | Description -| --- | --- -| !q | Quit the application -| !p | Delete the history used as chat context by the model -| !m | Toggle input mode (single-line <-> multi-line) +| Command | Description | +|---------|------------------------------------------------------| +| !q | Quit the application | +| !p | Delete the history used as chat context by the model | +| !i | Toggle input mode (single-line <-> multi-line) | +| !m | Select generative model | ### CLI help ```console @@ -54,7 +55,8 @@ Usage: Flags: -f, --format render markdown-formatted response (default true) -h, --help help for this command - -m, --multiline read input as a multi-line string + -m, --model string generative model name (default "gemini-pro") + --multiline read input as a multi-line string -s, --style string markdown format style (ascii, dark, light, pink, notty, dracula) (default "auto") -t, --term string multi-line input terminator (default "$") -v, --version version for this command diff --git a/cmd/gemini/main.go b/cmd/gemini/main.go index f13bc78..2c97370 100644 --- a/cmd/gemini/main.go +++ b/cmd/gemini/main.go @@ -22,15 +22,16 @@ func run() int { } var opts cli.ChatOpts + rootCmd.Flags().StringVarP(&opts.Model, "model", "m", gemini.DefaultModel, "generative model name") rootCmd.Flags().BoolVarP(&opts.Format, "format", "f", true, "render markdown-formatted response") rootCmd.Flags().StringVarP(&opts.Style, "style", "s", "auto", "markdown format style (ascii, dark, light, pink, notty, dracula)") - rootCmd.Flags().BoolVarP(&opts.Multiline, "multiline", "m", false, "read input as a multi-line string") + rootCmd.Flags().BoolVar(&opts.Multiline, "multiline", false, "read input as a multi-line string") rootCmd.Flags().StringVarP(&opts.Terminator, "term", "t", "$", "multi-line input terminator") rootCmd.RunE = func(_ *cobra.Command, _ []string) error { apiKey := os.Getenv(apiKeyEnv) - chatSession, err := gemini.NewChatSession(context.Background(), apiKey) + chatSession, err := gemini.NewChatSession(context.Background(), opts.Model, apiKey) if err != nil { return err } diff --git a/gemini/chat_session.go b/gemini/chat_session.go index 9cf76b1..0c5aea6 100644 --- a/gemini/chat_session.go +++ b/gemini/chat_session.go @@ -2,28 +2,36 @@ package gemini import ( "context" + "sync" "github.com/google/generative-ai-go/genai" "google.golang.org/api/option" ) -// ChatSession represents a gemini-pro powered chat session. +const DefaultModel = "gemini-pro" + +// ChatSession represents a gemini powered chat session. type ChatSession struct { - ctx context.Context + ctx context.Context + client *genai.Client session *genai.ChatSession + + loadModels sync.Once + models []string } -// NewChatSession returns a new ChatSession. -func NewChatSession(ctx context.Context, apiKey string) (*ChatSession, error) { +// NewChatSession returns a new [ChatSession]. +func NewChatSession(ctx context.Context, model, apiKey string) (*ChatSession, error) { client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) if err != nil { return nil, err } + return &ChatSession{ ctx: ctx, client: client, - session: client.GenerativeModel("gemini-pro").StartChat(), + session: client.GenerativeModel(model).StartChat(), }, nil } @@ -37,12 +45,36 @@ func (c *ChatSession) SendMessageStream(input string) *genai.GenerateContentResp return c.session.SendMessageStream(c.ctx, genai.Text(input)) } +// SetGenerativeModel sets the name of the generative model for the chat. +// It preserves the history from the previous chat session. +func (c *ChatSession) SetGenerativeModel(model string) { + history := c.session.History + c.session = c.client.GenerativeModel(model).StartChat() + c.session.History = history +} + +// ListModels returns a list of the supported generative model names. +func (c *ChatSession) ListModels() []string { + c.loadModels.Do(func() { + c.models = []string{DefaultModel} + iter := c.client.ListModels(c.ctx) + for { + modelInfo, err := iter.Next() + if err != nil { + break + } + c.models = append(c.models, modelInfo.Name) + } + }) + return c.models +} + // ClearHistory clears chat history. func (c *ChatSession) ClearHistory() { c.session.History = make([]*genai.Content, 0) } -// Close closes the genai.Client. +// Close closes the chat session. func (c *ChatSession) Close() error { return c.client.Close() } diff --git a/go.mod b/go.mod index 927cb4f..877224f 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/charmbracelet/glamour v0.7.0 github.com/chzyer/readline v1.5.1 github.com/google/generative-ai-go v0.18.0 + github.com/manifoldco/promptui v0.9.0 github.com/muesli/termenv v0.15.2 github.com/spf13/cobra v1.8.1 google.golang.org/api v0.196.0 diff --git a/go.sum b/go.sum index 0dcf083..0eb01bb 100644 --- a/go.sum +++ b/go.sum @@ -25,10 +25,13 @@ github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd3 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/charmbracelet/glamour v0.7.0 h1:2BtKGZ4iVJCDfMF229EzbeR1QRKLWztO9dMtjmqZSng= github.com/charmbracelet/glamour v0.7.0/go.mod h1:jUMh5MeihljJPQbJ/wf4ldw2+yBP59+ctV36jASy7ps= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= @@ -93,6 +96,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= +github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= @@ -169,6 +174,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/cli/chat.go b/internal/cli/chat.go index b637c76..e9d6f2f 100644 --- a/internal/cli/chat.go +++ b/internal/cli/chat.go @@ -11,6 +11,7 @@ import ( // ChatOpts represents Chat configuration options. type ChatOpts struct { + Model string Format bool Style string Multiline bool diff --git a/internal/cli/command.go b/internal/cli/command.go index 0ef135e..7215714 100644 --- a/internal/cli/command.go +++ b/internal/cli/command.go @@ -17,7 +17,8 @@ const ( systemCmdPrefix = "!" systemCmdQuit = "!q" systemCmdPurgeHistory = "!p" - systemCmdToggleInputMode = "!m" + systemCmdSelectInputMode = "!i" + systemCmdSelectModel = "!m" ) type command interface { @@ -44,18 +45,39 @@ func (c *systemCommand) run(message string) bool { case systemCmdPurgeHistory: c.chat.model.ClearHistory() c.print("Cleared the chat history.") - case systemCmdToggleInputMode: + case systemCmdSelectInputMode: + multiline, err := selectInputMode(c.chat.opts.Multiline) + if err != nil { + c.error(err) + break + } + if multiline == c.chat.opts.Multiline { + c.printSelectedCurrent() + break + } + c.chat.opts.Multiline = multiline if c.chat.opts.Multiline { - c.print("Switched to single-line input mode.") - c.chat.reader.HistoryEnable() - c.chat.opts.Multiline = false - } else { c.print("Switched to multi-line input mode.") // disable history for multi-line messages since it is // unusable for future requests c.chat.reader.HistoryDisable() - c.chat.opts.Multiline = true + } else { + c.print("Switched to single-line input mode.") + c.chat.reader.HistoryEnable() + } + case systemCmdSelectModel: + model, err := selectModel(c.chat.opts.Model, c.chat.model.ListModels()) + if err != nil { + c.error(err) + break + } + if model == c.chat.opts.Model { + c.printSelectedCurrent() + break } + c.chat.opts.Model = model + c.chat.model.SetGenerativeModel(model) + c.print(fmt.Sprintf("Selected '%s' generative model.", model)) default: c.print("Unknown system command.") } @@ -66,6 +88,14 @@ func (c *systemCommand) print(message string) { fmt.Printf("%s%s\n", c.chat.prompt.cli, message) } +func (c *systemCommand) printSelectedCurrent() { + fmt.Printf("%sThe selection is unchanged.\n", c.chat.prompt.cli) +} + +func (c *systemCommand) error(err error) { + fmt.Printf(color.Red("%s%s\n"), c.chat.prompt.cli, err) +} + type geminiCommand struct { chat *Chat spinner *spinner @@ -104,7 +134,7 @@ func (c *geminiCommand) runBlocking(message string) { var buf strings.Builder for _, candidate := range response.Candidates { for _, part := range candidate.Content.Parts { - buf.WriteString(fmt.Sprintf("%s", part)) + fmt.Fprintf(&buf, "%s", part) } } output, err := glamour.Render(buf.String(), c.chat.opts.Style) @@ -138,6 +168,6 @@ func (c *geminiCommand) runStreaming(message string) { } func (c *geminiCommand) printFlush(message string) { - fmt.Fprintf(c.writer, "%s", message) - c.writer.Flush() + _, _ = c.writer.WriteString(message) + _ = c.writer.Flush() } diff --git a/internal/cli/select.go b/internal/cli/select.go new file mode 100644 index 0000000..d82576c --- /dev/null +++ b/internal/cli/select.go @@ -0,0 +1,51 @@ +package cli + +import ( + "slices" + + "github.com/manifoldco/promptui" +) + +var ( + inputMode = []string{"single-line", "multi-line"} +) + +// selectModel returns the selected generative model name. +func selectModel(current string, models []string) (string, error) { + prompt := promptui.Select{ + Label: "Select generative model", + HideSelected: true, + Items: models, + CursorPos: slices.Index(models, current), + } + + _, result, err := prompt.Run() + if err != nil { + return "", err + } + + return result, nil +} + +// selectInputMode returns true if multiline input is selected; +// otherwise, it returns false. +func selectInputMode(multiline bool) (bool, error) { + var cursorPos int + if multiline { + cursorPos = 1 + } + + prompt := promptui.Select{ + Label: "Select input mode", + HideSelected: true, + Items: inputMode, + CursorPos: cursorPos, + } + + _, result, err := prompt.Run() + if err != nil { + return false, err + } + + return result == inputMode[1], nil +}