feat!: support multiple generative models by reugn · Pull Request #29 · reugn/gemini-cli · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat!: support multiple generative models #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ If you don't already have one, create a key in [Google AI Studio](https://makers
The system chat message must begin with an exclamation mark and is used for internal operations.
A short list of supported system commands:

| Command | Description
| --- | ---
| !q | Quit the application
| !p | Delete the history used as chat context by the model
| !m | Toggle input mode (single-line <-> multi-line)
| Command | Description |
|---------|------------------------------------------------------|
| !q | Quit the application |
| !p | Delete the history used as chat context by the model |
| !i | Toggle input mode (single-line <-> multi-line) |
| !m | Select generative model |

### CLI help
```console
Expand All @@ -54,7 +55,8 @@ Usage:
Flags:
-f, --format render markdown-formatted response (default true)
-h, --help help for this command
-m, --multiline read input as a multi-line string
-m, --model string generative model name (default "gemini-pro")
--multiline read input as a multi-line string
-s, --style string markdown format style (ascii, dark, light, pink, notty, dracula) (default "auto")
-t, --term string multi-line input terminator (default "$")
-v, --version version for this command
Expand Down
5 changes: 3 additions & 2 deletions cmd/gemini/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ func run() int {
}

var opts cli.ChatOpts
rootCmd.Flags().StringVarP(&opts.Model, "model", "m", gemini.DefaultModel, "generative model name")
rootCmd.Flags().BoolVarP(&opts.Format, "format", "f", true, "render markdown-formatted response")
rootCmd.Flags().StringVarP(&opts.Style, "style", "s", "auto",
"markdown format style (ascii, dark, light, pink, notty, dracula)")
rootCmd.Flags().BoolVarP(&opts.Multiline, "multiline", "m", false, "read input as a multi-line string")
rootCmd.Flags().BoolVar(&opts.Multiline, "multiline", false, "read input as a multi-line string")
rootCmd.Flags().StringVarP(&opts.Terminator, "term", "t", "$", "multi-line input terminator")

rootCmd.RunE = func(_ *cobra.Command, _ []string) error {
apiKey := os.Getenv(apiKeyEnv)
chatSession, err := gemini.NewChatSession(context.Background(), apiKey)
chatSession, err := gemini.NewChatSession(context.Background(), opts.Model, apiKey)
if err != nil {
return err
}
Expand Down
44 changes: 38 additions & 6 deletions gemini/chat_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,36 @@ package gemini

import (
	"context"
	"slices"
	"strings"
	"sync"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
)

// DefaultModel is the name of the generative model used when no model is
// specified explicitly; it is the default value of the --model flag.
const DefaultModel = "gemini-pro"

// ChatSession represents a gemini powered chat session.
type ChatSession struct {
ctx context.Context
ctx context.Context

client *genai.Client
session *genai.ChatSession

loadModels sync.Once
models []string
}

// NewChatSession returns a new ChatSession.
func NewChatSession(ctx context.Context, apiKey string) (*ChatSession, error) {
// NewChatSession returns a new [ChatSession].
func NewChatSession(ctx context.Context, model, apiKey string) (*ChatSession, error) {
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
if err != nil {
return nil, err
}

return &ChatSession{
ctx: ctx,
client: client,
session: client.GenerativeModel("gemini-pro").StartChat(),
session: client.GenerativeModel(model).StartChat(),
}, nil
}

Expand All @@ -37,12 +45,36 @@ func (c *ChatSession) SendMessageStream(input string) *genai.GenerateContentResp
return c.session.SendMessageStream(c.ctx, genai.Text(input))
}

// SetGenerativeModel switches the chat to the generative model with the given
// name. A fresh chat session is started for that model and the conversation
// history of the session being replaced is carried over to it.
func (c *ChatSession) SetGenerativeModel(model string) {
	previous := c.session
	c.session = c.client.GenerativeModel(model).StartChat()
	c.session.History = previous.History
}

// ListModels returns the names of the generative models available to the
// client, with DefaultModel always listed first.
//
// Only models that advertise the "generateContent" generation method are
// included, because other models (for example embedding models) cannot back
// a chat session. API model names are normalized by dropping the "models/"
// resource prefix so they use the same form as DefaultModel and the --model
// flag, and duplicates are skipped.
//
// The list is fetched from the API once and cached. Fetching is best-effort:
// iteration stops at the first error returned by the iterator, keeping the
// names collected so far. A copy is returned so callers cannot modify the
// cached list.
func (c *ChatSession) ListModels() []string {
	c.loadModels.Do(func() {
		c.models = []string{DefaultModel}
		iter := c.client.ListModels(c.ctx)
		for {
			modelInfo, err := iter.Next()
			if err != nil {
				// End of the listing (or a fetch failure): keep what we have,
				// since this method has no way to report an error.
				break
			}
			if !slices.Contains(modelInfo.SupportedGenerationMethods, "generateContent") {
				continue
			}
			name := strings.TrimPrefix(modelInfo.Name, "models/")
			if !slices.Contains(c.models, name) {
				c.models = append(c.models, name)
			}
		}
	})
	return slices.Clone(c.models)
}

// ClearHistory drops the conversation history of the current chat session,
// replacing it with an empty, non-nil history slice.
func (c *ChatSession) ClearHistory() {
	c.session.History = []*genai.Content{}
}

// Close shuts down the chat session by closing its underlying genai.Client,
// returning whatever error the client reports.
func (c *ChatSession) Close() error {
	if err := c.client.Close(); err != nil {
		return err
	}
	return nil
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/charmbracelet/glamour v0.7.0
github.com/chzyer/readline v1.5.1
github.com/google/generative-ai-go v0.18.0
github.com/manifoldco/promptui v0.9.0
github.com/muesli/termenv v0.15.2
github.com/spf13/cobra v1.8.1
google.golang.org/api v0.196.0
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd3
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/charmbracelet/glamour v0.7.0 h1:2BtKGZ4iVJCDfMF229EzbeR1QRKLWztO9dMtjmqZSng=
github.com/charmbracelet/glamour v0.7.0/go.mod h1:jUMh5MeihljJPQbJ/wf4ldw2+yBP59+ctV36jASy7ps=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
Expand Down Expand Up @@ -93,6 +96,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA=
github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
Expand Down Expand Up @@ -169,6 +174,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
1 change: 1 addition & 0 deletions internal/cli/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

// ChatOpts represents Chat configuration options.
type ChatOpts struct {
Model string
Format bool
Style string
Multiline bool
Expand Down
50 changes: 40 additions & 10 deletions internal/cli/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ const (
systemCmdPrefix = "!"
systemCmdQuit = "!q"
systemCmdPurgeHistory = "!p"
systemCmdToggleInputMode = "!m"
systemCmdSelectInputMode = "!i"
systemCmdSelectModel = "!m"
)

type command interface {
Expand All @@ -44,18 +45,39 @@ func (c *systemCommand) run(message string) bool {
case systemCmdPurgeHistory:
c.chat.model.ClearHistory()
c.print("Cleared the chat history.")
case systemCmdToggleInputMode:
case systemCmdSelectInputMode:
multiline, err := selectInputMode(c.chat.opts.Multiline)
if err != nil {
c.error(err)
break
}
if multiline == c.chat.opts.Multiline {
c.printSelectedCurrent()
break
}
c.chat.opts.Multiline = multiline
if c.chat.opts.Multiline {
c.print(" F438 ;Switched to single-line input mode.")
c.chat.reader.HistoryEnable()
c.chat.opts.Multiline = false
} else {
c.print("Switched to multi-line input mode.")
// disable history for multi-line messages since it is
// unusable for future requests
c.chat.reader.HistoryDisable()
c.chat.opts.Multiline = true
} else {
c.print("Switched to single-line input mode.")
c.chat.reader.HistoryEnable()
}
case systemCmdSelectModel:
model, err := selectModel(c.chat.opts.Model, c.chat.model.ListModels())
if err != nil {
c.error(err)
break
}
if model == c.chat.opts.Model {
c.printSelectedCurrent()
break
}
c.chat.opts.Model = model
c.chat.model.SetGenerativeModel(model)
c.print(fmt.Sprintf("Selected '%s' generative model.", model))
default:
c.print("Unknown system command.")
}
Expand All @@ -66,6 +88,14 @@ func (c *systemCommand) print(message string) {
fmt.Printf("%s%s\n", c.chat.prompt.cli, message)
}

// printSelectedCurrent tells the user that the interactive selection left the
// current setting unchanged, using the same prompt-prefixed format as print.
func (c *systemCommand) printSelectedCurrent() {
	c.print("The selection is unchanged.")
}

// error prints err after the CLI prompt prefix; the whole format string,
// including the trailing newline, is wrapped by color.Red.
func (c *systemCommand) error(err error) {
	redFormat := color.Red("%s%s\n")
	fmt.Printf(redFormat, c.chat.prompt.cli, err)
}

type geminiCommand struct {
chat *Chat
spinner *spinner
Expand Down Expand Up @@ -104,7 +134,7 @@ func (c *geminiCommand) runBlocking(message string) {
var buf strings.Builder
for _, candidate := range response.Candidates {
for _, part := range candidate.Content.Parts {
buf.WriteString(fmt.Sprintf("%s", part))
fmt.Fprintf(&buf, "%s", part)
}
}
output, err := glamour.Render(buf.String(), c.chat.opts.Style)
Expand Down Expand Up @@ -138,6 +168,6 @@ func (c *geminiCommand) runStreaming(message string) {
}

func (c *geminiCommand) printFlush(message string) {
fmt.Fprintf(c.writer, "%s", message)
c.writer.Flush()
_, _ = c.writer.WriteString(message)
_ = c.writer.Flush()
}
51 changes: 51 additions & 0 deletions internal/cli/select.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package cli

import (
"slices"

"github.com/manifoldco/promptui"
)

// inputMode lists the input modes offered by selectInputMode in display
// order: index 0 is single-line input, index 1 is multi-line input.
var inputMode = []string{"single-line", "multi-line"}

// selectModel shows an interactive list of models and returns the name the
// user picks. The cursor initially rests on current; slices.Index yields -1
// when current is not in models.
// NOTE(review): promptui appears to clamp a negative CursorPos to the first
// item — confirm against the promptui version pinned in go.mod.
// Any error from the prompt is returned together with an empty name.
func selectModel(current string, models []string) (string, error) {
	startAt := slices.Index(models, current)
	picker := &promptui.Select{
		Label:        "Select generative model",
		Items:        models,
		CursorPos:    startAt,
		HideSelected: true,
	}

	_, name, err := picker.Run()
	if err != nil {
		return "", err
	}
	return name, nil
}

// selectInputMode returns true if multiline input is selected;
// otherwise, it returns false.
func selectInputMode(multiline bool) (bool, error) {
var cursorPos int
if multiline {
cursorPos = 1
}

prompt := promptui.Select{
Label: "Select input mode",
HideSelected: true,
Items: inputMode,
CursorPos: cursorPos,
}

_, result, err := prompt.Run()
if err != nil {
return false, err
}

return result == inputMode[1], nil
}
Loading
0