From 4905fd686d8692203977315bf63d267d28a0b61d Mon Sep 17 00:00:00 2001 From: Adrian Thurston Date: Tue, 24 Jun 2025 14:28:09 -0700 Subject: [PATCH] feat: allow the user to stop processing flags after seeing N args We have switched from v2 to v3 mainly because we want to support parsing flags after the arguments. There are many other benefits, but this one drove the work. We do have some commands where the old v2 behaviour is desirable, specifically when we pass arguments through. For example, our ssh command: tool ssh MACHINE-NAME ls -la To support this on a per-command basis, this change adds an option to the Command struct called StopOnNthArg. Once N positional arguments have been seen, it terminates flag parsing just as '--' does. --- command.go | 8 + command_parse.go | 7 + command_run.go | 5 + command_stop_on_nth_arg_test.go | 301 ++++++++++++++++++++++++++++++++ command_test.go | 24 ++- godoc-current.txt | 8 + testdata/godoc-v3.x.txt | 8 + 7 files changed, 353 insertions(+), 8 deletions(-) create mode 100644 command_stop_on_nth_arg_test.go diff --git a/command.go b/command.go index 541081a59c..0505f64885 100644 --- a/command.go +++ b/command.go @@ -127,6 +127,14 @@ type Command struct { // Whether to read arguments from stdin // applicable to root command only ReadArgsFromStdin bool `json:"readArgsFromStdin"` + // StopOnNthArg provides v2-like behavior for specific commands by stopping + // flag parsing after N positional arguments are encountered. When set to N, + // all remaining arguments after the Nth positional argument will be treated + // as arguments, not flags. + // + // A value of 0 means all arguments are treated as positional (no flag parsing). + // A nil value means normal v3 flag parsing behavior (flags can appear anywhere). 
+ StopOnNthArg *int `json:"stopOnNthArg"` // categories contains the categorized commands and is populated on app startup categories CommandCategories diff --git a/command_parse.go b/command_parse.go index 1130658005..7e3e0ab357 100644 --- a/command_parse.go +++ b/command_parse.go @@ -89,6 +89,13 @@ func (cmd *Command) parseFlags(args Args) (Args, error) { return &stringSliceArgs{posArgs}, nil } + // Check if we've reached the Nth argument and should stop flag parsing + if cmd.StopOnNthArg != nil && len(posArgs) == *cmd.StopOnNthArg { + // Append current arg and all remaining args without parsing + posArgs = append(posArgs, rargs[0:]...) + return &stringSliceArgs{posArgs}, nil + } + // handle positional args if firstArg[0] != '-' { // positional argument probably diff --git a/command_run.go b/command_run.go index 24b7935166..f91ce4336a 100644 --- a/command_run.go +++ b/command_run.go @@ -99,6 +99,11 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context tracef("running with arguments %[1]q (cmd=%[2]q)", osArgs, cmd.Name) cmd.setupDefaults(osArgs) + // Validate StopOnNthArg + if cmd.StopOnNthArg != nil && *cmd.StopOnNthArg < 0 { + return ctx, fmt.Errorf("StopOnNthArg must be non-negative, got %d", *cmd.StopOnNthArg) + } + if v, ok := ctx.Value(commandContextKey).(*Command); ok { tracef("setting parent (cmd=%[1]q) command from context.Context value (cmd=%[2]q)", v.Name, cmd.Name) cmd.parent = v diff --git a/command_stop_on_nth_arg_test.go b/command_stop_on_nth_arg_test.go new file mode 100644 index 0000000000..72215385f6 --- /dev/null +++ b/command_stop_on_nth_arg_test.go @@ -0,0 +1,301 @@ +package cli + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCommand_StopOnNthArg(t *testing.T) { + tests := []struct { + name string + stopOnNthArg *int + testArgs []string + expectedArgs []string + expectedFlag string + expectedBool bool + }{ + { + name: "nil 
StopOnNthArg - normal parsing", + stopOnNthArg: nil, + testArgs: []string{"cmd", "--flag", "value", "arg1", "--bool", "arg2"}, + expectedArgs: []string{"arg1", "arg2"}, + expectedFlag: "value", + expectedBool: true, + }, + { + name: "stop after 0 args - all become args", + stopOnNthArg: intPtr(0), + testArgs: []string{"cmd", "--flag", "value", "arg1", "--bool", "arg2"}, + expectedArgs: []string{"--flag", "value", "arg1", "--bool", "arg2"}, + expectedFlag: "", + expectedBool: false, + }, + { + name: "stop after 1 arg", + stopOnNthArg: intPtr(1), + testArgs: []string{"cmd", "--flag", "value", "arg1", "--bool", "arg2"}, + expectedArgs: []string{"arg1", "--bool", "arg2"}, + expectedFlag: "value", + expectedBool: false, + }, + { + name: "stop after 2 args", + stopOnNthArg: intPtr(2), + testArgs: []string{"cmd", "--flag", "value", "arg1", "arg2", "--bool", "arg3"}, + expectedArgs: []string{"arg1", "arg2", "--bool", "arg3"}, + expectedFlag: "value", + expectedBool: false, + }, + { + name: "mixed flags and args - stop after 1", + stopOnNthArg: intPtr(1), + testArgs: []string{"cmd", "--flag", "value", "--bool", "arg1", "--flag2", "value2"}, + expectedArgs: []string{"arg1", "--flag2", "value2"}, + expectedFlag: "value", + expectedBool: true, + }, + { + name: "args before flags - stop after 1", + stopOnNthArg: intPtr(1), + testArgs: []string{"cmd", "arg1", "--flag", "value", "--bool"}, + expectedArgs: []string{"arg1", "--flag", "value", "--bool"}, + expectedFlag: "", + expectedBool: false, + }, + { + name: "ssh command example", + stopOnNthArg: intPtr(1), + testArgs: []string{"ssh", "machine-name", "ls", "-la"}, + expectedArgs: []string{"machine-name", "ls", "-la"}, + expectedFlag: "", + expectedBool: false, + }, + { + name: "with double dash terminator", + stopOnNthArg: intPtr(1), + testArgs: []string{"cmd", "--flag", "value", "--", "arg1", "--not-a-flag"}, + expectedArgs: []string{"arg1", "--not-a-flag"}, + expectedFlag: "value", + expectedBool: false, + }, + { + name: 
"stop after large number of args", + stopOnNthArg: intPtr(100), + testArgs: []string{"cmd", "--flag", "value", "arg1", "arg2", "--bool"}, + expectedArgs: []string{"arg1", "arg2"}, + expectedFlag: "value", + expectedBool: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var args Args + var flagValue string + var boolValue bool + + cmd := &Command{ + Name: "test", + StopOnNthArg: tt.stopOnNthArg, + Flags: []Flag{ + &StringFlag{Name: "flag", Destination: &flagValue}, + &StringFlag{Name: "flag2"}, + &BoolFlag{Name: "bool", Destination: &boolValue}, + }, + Action: func(_ context.Context, cmd *Command) error { + args = cmd.Args() + return nil + }, + } + + require.NoError(t, cmd.Run(buildTestContext(t), tt.testArgs)) + assert.Equal(t, tt.expectedArgs, args.Slice()) + assert.Equal(t, tt.expectedFlag, flagValue) + assert.Equal(t, tt.expectedBool, boolValue) + }) + } +} + +func TestCommand_StopOnNthArg_WithSubcommands(t *testing.T) { + tests := []struct { + name string + parentStopOnNthArg *int + subStopOnNthArg *int + testArgs []string + expectedParentArgs []string + expectedSubArgs []string + expectedSubFlag string + }{ + { + name: "parent normal, subcommand stops after 0", + parentStopOnNthArg: nil, + subStopOnNthArg: intPtr(0), + testArgs: []string{"parent", "sub", "--subflag", "value", "subarg", "--not-parsed"}, + expectedParentArgs: []string{}, + expectedSubArgs: []string{"--subflag", "value", "subarg", "--not-parsed"}, + expectedSubFlag: "", + }, + { + name: "parent normal, subcommand stops after 1", + parentStopOnNthArg: nil, + subStopOnNthArg: intPtr(1), + testArgs: []string{"parent", "sub", "--subflag", "value", "subarg", "--not-parsed"}, + expectedParentArgs: []string{}, + expectedSubArgs: []string{"subarg", "--not-parsed"}, + expectedSubFlag: "value", + }, + { + name: "parent normal, subcommand stops after 2", + parentStopOnNthArg: nil, + subStopOnNthArg: intPtr(2), + testArgs: []string{"parent", "sub", "--subflag", "value", 
"subarg1", "subarg2", "--not-parsed"}, + expectedParentArgs: []string{}, + expectedSubArgs: []string{"subarg1", "subarg2", "--not-parsed"}, + expectedSubFlag: "value", + }, + { + name: "parent normal, subcommand never stops (high StopOnNthArg)", + parentStopOnNthArg: nil, + subStopOnNthArg: intPtr(100), + testArgs: []string{"parent", "sub", "--subflag", "value1", "arg1", "arg2", "--subflag", "value2"}, + expectedParentArgs: []string{}, + expectedSubArgs: []string{"arg1", "arg2"}, + expectedSubFlag: "value2", // Should parse the second --subflag since we never hit the stop limit + }, + { + // Meaningless, but okay. + name: "parent stops after 1, subcommand stops after 1", + parentStopOnNthArg: intPtr(1), + subStopOnNthArg: intPtr(1), + testArgs: []string{"parent", "sub", "--subflag", "value", "subarg", "--not-parsed"}, + expectedParentArgs: []string{}, + expectedSubArgs: []string{"subarg", "--not-parsed"}, + expectedSubFlag: "value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var parentArgs, subArgs Args + var subFlagValue string + subCalled := false + + subCmd := &Command{ + Name: "sub", + StopOnNthArg: tt.subStopOnNthArg, + Flags: []Flag{ + &StringFlag{Name: "subflag", Destination: &subFlagValue}, + }, + Action: func(_ context.Context, cmd *Command) error { + subCalled = true + subArgs = cmd.Args() + return nil + }, + } + + parentCmd := &Command{ + Name: "parent", + StopOnNthArg: tt.parentStopOnNthArg, + Commands: []*Command{subCmd}, + Flags: []Flag{ + &StringFlag{Name: "parentflag"}, + }, + Action: func(_ context.Context, cmd *Command) error { + parentArgs = cmd.Args() + return nil + }, + } + + err := parentCmd.Run(buildTestContext(t), tt.testArgs) + + require.NoError(t, err) + + if tt.expectedSubArgs != nil { + assert.True(t, subCalled, "subcommand should have been called") + if len(tt.expectedSubArgs) > 0 { + haveNonEmptySubArgsSlice := subArgs != nil && subArgs.Slice() != nil && len(subArgs.Slice()) > 0 + assert.True(t, 
haveNonEmptySubArgsSlice, "subargs.Slice is not nil") + if haveNonEmptySubArgsSlice { + assert.Equal(t, tt.expectedSubArgs, subArgs.Slice()) + } + } else { + assert.True(t, subArgs == nil || subArgs.Slice() == nil || len(subArgs.Slice()) == 0, "subargs.Slice is not nil") + } + assert.Equal(t, tt.expectedSubFlag, subFlagValue) + } else { + assert.False(t, subCalled, "subcommand should not have been called") + assert.Equal(t, tt.expectedParentArgs, parentArgs.Slice()) + } + }) + } +} + +func TestCommand_StopOnNthArg_EdgeCases(t *testing.T) { + t.Run("negative StopOnNthArg returns error", func(t *testing.T) { + cmd := &Command{ + Name: "test", + StopOnNthArg: intPtr(-1), + Action: func(_ context.Context, cmd *Command) error { + return nil + }, + } + + // Negative value should return an error + err := cmd.Run(buildTestContext(t), []string{"cmd", "arg1"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "StopOnNthArg must be non-negative") + }) + + t.Run("zero StopOnNthArg with no args", func(t *testing.T) { + var args Args + var flagValue string + cmd := &Command{ + Name: "test", + StopOnNthArg: intPtr(0), + Flags: []Flag{ + &StringFlag{Name: "flag", Destination: &flagValue}, + }, + Action: func(_ context.Context, cmd *Command) error { + args = cmd.Args() + return nil + }, + } + + // All flags should become args + require.NoError(t, cmd.Run(buildTestContext(t), []string{"cmd", "--flag", "value"})) + assert.Equal(t, []string{"--flag", "value"}, args.Slice()) + assert.Equal(t, "", flagValue) + }) + + t.Run("StopOnNthArg with only flags", func(t *testing.T) { + var args Args + var flagValue string + var boolValue bool + cmd := &Command{ + Name: "test", + StopOnNthArg: intPtr(1), + Flags: []Flag{ + &StringFlag{Name: "flag", Destination: &flagValue}, + &BoolFlag{Name: "bool", Destination: &boolValue}, + }, + Action: func(_ context.Context, cmd *Command) error { + args = cmd.Args() + return nil + }, + } + + // Should parse all flags since no args are encountered + 
require.NoError(t, cmd.Run(buildTestContext(t), []string{"cmd", "--flag", "value", "--bool"})) + assert.Equal(t, []string{}, args.Slice()) + assert.Equal(t, "value", flagValue) + assert.True(t, boolValue) + }) +} + +// Helper function to create int pointer +func intPtr(i int) *int { + return &i +} diff --git a/command_test.go b/command_test.go index b9e69bc307..a1f6641a63 100644 --- a/command_test.go +++ b/command_test.go @@ -4743,7 +4743,8 @@ func TestJSONExportCommand(t *testing.T) { "prefixMatchCommands": false, "mutuallyExclusiveFlags": null, "arguments": null, - "readArgsFromStdin": false + "readArgsFromStdin": false, + "stopOnNthArg": null } ], "flags": [ @@ -4805,7 +4806,8 @@ func TestJSONExportCommand(t *testing.T) { "prefixMatchCommands": false, "mutuallyExclusiveFlags": null, "arguments": null, - "readArgsFromStdin": false + "readArgsFromStdin": false, + "stopOnNthArg": null }, { "name": "info", @@ -4838,7 +4840,8 @@ func TestJSONExportCommand(t *testing.T) { "prefixMatchCommands": false, "mutuallyExclusiveFlags": null, "arguments": null, - "readArgsFromStdin": false + "readArgsFromStdin": false, + "stopOnNthArg": null }, { "name": "some-command", @@ -4868,7 +4871,8 @@ func TestJSONExportCommand(t *testing.T) { "prefixMatchCommands": false, "mutuallyExclusiveFlags": null, "arguments": null, - "readArgsFromStdin": false + "readArgsFromStdin": false, + "stopOnNthArg": null }, { "name": "hidden-command", @@ -4917,7 +4921,8 @@ func TestJSONExportCommand(t *testing.T) { "prefixMatchCommands": false, "mutuallyExclusiveFlags": null, "arguments": null, - "readArgsFromStdin": false + "readArgsFromStdin": false, + "stopOnNthArg": null }, { "name": "usage", @@ -4983,7 +4988,8 @@ func TestJSONExportCommand(t *testing.T) { "prefixMatchCommands": false, "mutuallyExclusiveFlags": null, "arguments": null, - "readArgsFromStdin": false + "readArgsFromStdin": false, + "stopOnNthArg": null } ], "flags": [ @@ -5045,7 +5051,8 @@ func TestJSONExportCommand(t *testing.T) { 
"prefixMatchCommands": false, "mutuallyExclusiveFlags": null, "arguments": null, - "readArgsFromStdin": false + "readArgsFromStdin": false, + "stopOnNthArg": null } ], "flags": [ @@ -5162,7 +5169,8 @@ func TestJSONExportCommand(t *testing.T) { } } ], - "readArgsFromStdin": false + "readArgsFromStdin": false, + "stopOnNthArg": null } ` assert.JSONEq(t, expected, string(out)) diff --git a/godoc-current.txt b/godoc-current.txt index edd42feb88..f2d548fb21 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -514,6 +514,14 @@ type Command struct { // Whether to read arguments from stdin // applicable to root command only ReadArgsFromStdin bool `json:"readArgsFromStdin"` + // StopOnNthArg provides v2-like behavior for specific commands by stopping + // flag parsing after N positional arguments are encountered. When set to N, + // all remaining arguments after the Nth positional argument will be treated + // as arguments, not flags. + // + // A value of 0 means all arguments are treated as positional (no flag parsing). + // A nil value means normal v3 flag parsing behavior (flags can appear anywhere). + StopOnNthArg *int `json:"stopOnNthArg"` // Has unexported fields. } diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index edd42feb88..f2d548fb21 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -514,6 +514,14 @@ type Command struct { // Whether to read arguments from stdin // applicable to root command only ReadArgsFromStdin bool `json:"readArgsFromStdin"` + // StopOnNthArg provides v2-like behavior for specific commands by stopping + // flag parsing after N positional arguments are encountered. When set to N, + // all remaining arguments after the Nth positional argument will be treated + // as arguments, not flags. + // + // A value of 0 means all arguments are treated as positional (no flag parsing). + // A nil value means normal v3 flag parsing behavior (flags can appear anywhere). 
+ StopOnNthArg *int `json:"stopOnNthArg"` // Has unexported fields. }