BED-5810 - Make stbernard Less Aggressive during Golang Code Generation by zinic · Pull Request #1387 · SpecterOps/BloodHound · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

BED-5810 - Make stbernard Less Aggressive during Golang Code Generation #1387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 111 additions & 44 deletions packages/go/stbernard/workspace/golang/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,69 +17,136 @@
package golang

import (
"bufio"
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"sync"

"github.com/specterops/bloodhound/dawgs/util/channels"
"github.com/specterops/bloodhound/packages/go/stbernard/cmdrunner"
"github.com/specterops/bloodhound/packages/go/stbernard/environment"
)

// WorkspaceGenerate runs go generate ./... for all module paths passed
func WorkspaceGenerate(modPaths []string, env environment.Environment) error {
var (
errs []error
wg sync.WaitGroup
mu sync.Mutex
)
// fileContentContainsGenerationDirective uses a bufio.Scanner to search a given reader line-by-line
// looking for golang code generation directives. Upon finding the first code generation directive
// this function returns true with no error. If no code generation directives exist this function will
// return false instead.
//
// A line counts as a directive when, after surrounding whitespace is trimmed, it begins with
// "//go:generate". If the scanner stops because of an error, false is returned together with that error.
//
// NOTE(review): this span comes from a rendered diff without +/- markers. The lines between the scanner
// construction and the scan loop (the modPath goroutine using wg, mu and moduleGenerate) appear to be
// removed-side lines from the previous WorkspaceGenerate body - confirm against the merged file.
func fileContentContainsGenerationDirective(fin io.Reader) (bool, error) {
// NOTE(review): scanner.Buffer is never called, so bufio's default maximum token size
// (bufio.MaxScanTokenSize) applies; a single line longer than that ends the scan and is reported
// through scanner.Err() below.
scanner := bufio.NewScanner(fin)

for _, modPath := range modPaths {
wg.Add(1)
go func(modPath string) {
defer wg.Done()
if err := moduleGenerate(modPath, env); err != nil {
mu.Lock()
errs = append(errs, fmt.Errorf("code generation for module %s: %w", modPath, err))
mu.Unlock()
for scanner.Scan() {
// Trim first so that an indented directive line is matched as well
line := strings.TrimSpace(scanner.Text())

if strings.HasPrefix(line, "//go:generate") {
// If we find a go generate directive return right away
return true, nil
}
}

// No directive was found; scanner.Err() is nil when the input was simply exhausted
return false, scanner.Err()
}

// packageHasGenerationDirectives scans a golang package at a given path for any files that contain a
// golang code generation directive. Upon finding any file in the package that contains a code
// generation directive, this function returns true with no error. If no code generation directives exist
// in the package, this function returns false instead.
//
// The scan uses filepath.Walk, which visits the whole file tree rooted at packagePath, so ".go" files in
// nested directories are inspected too. Any error returned by the walk callback (opening or scanning a
// file) aborts the walk and is returned with false.
//
// NOTE(review): this span comes from a rendered diff without +/- markers. The `}(modPath)` and `wg.Wait()`
// lines appear to be removed-side lines from the previous WorkspaceGenerate body - confirm against the
// merged file.
func packageHasGenerationDirectives(packagePath string) (bool, error) {
hasGolangCodeGenDirectives := false

if err := filepath.Walk(packagePath, func(path string, info os.FileInfo, err error) error {
// NOTE(review): the err argument supplied by filepath.Walk is not inspected before info is used.
// The filepath.WalkFunc documentation describes cases where err is non-nil and info is nil -
// confirm that cannot be reached here.

// Don't bother reading anything that isn't a golang source file
if info.IsDir() || filepath.Ext(info.Name()) != ".go" {
return nil
}

// Open the file and search for code generation directives
if fin, err := os.Open(path); err != nil {
return err
} else {
// Deferred to the return of this callback, so each visited file is closed before the next one
defer fin.Close()

// Plain assignment (not :=) so the result lands in the outer hasGolangCodeGenDirectives
if hasGolangCodeGenDirectives, err = fileContentContainsGenerationDirective(fin); err != nil {
return err
} else if hasGolangCodeGenDirectives {
// Skip the rest of the FS walk for this package
return filepath.SkipAll
}
}(modPath)
}

return nil
}); err != nil {
return false, err
}

wg.Wait()
return hasGolangCodeGenDirectives, nil
}

return errors.Join(errs...)
// parallelGenerateModulePackages starts runtime.NumCPU() worker goroutines that attempt golang code generation
// for each GoPackage received over jobC.
//
// Every worker is registered on waitGroup before it starts and marks itself done when it exits. A worker exits
// once channels.Receive reports that it cannot continue. For each package received, the worker first checks
// whether the package contains a code generation directive and only then runs `go generate <package dir>`,
// using the package directory as the working directory.
//
// Failures never stop a worker. They are handed to addErr, wrapped with the directory of the package that was
// being processed so that the aggregated error can be traced back to a package. addErr is invoked from several
// goroutines and therefore must be safe for concurrent use.
func parallelGenerateModulePackages(jobC <-chan GoPackage, waitGroup *sync.WaitGroup, env environment.Environment, addErr func(err error)) {
	// The worker count does not change while workers are being started, so read it once
	numWorkers := runtime.NumCPU()

	for workerID := 1; workerID <= numWorkers; workerID++ {
		waitGroup.Add(1)

		go func() {
			defer waitGroup.Done()

			for {
				nextPackage, canContinue := channels.Receive(context.TODO(), jobC)
				if !canContinue {
					return
				}

				hasGenerationDirectives, err := packageHasGenerationDirectives(nextPackage.Dir)
				if err != nil {
					addErr(fmt.Errorf("scanning package %s for generation directives: %w", nextPackage.Dir, err))
					continue
				}

				if !hasGenerationDirectives {
					// Nothing to generate, so avoid spawning a go process for this package
					continue
				}

				args := []string{"generate", nextPackage.Dir}

				if err := cmdrunner.Run("go", args, nextPackage.Dir, env); err != nil {
					addErr(fmt.Errorf("generate code for package %s: %w", nextPackage.Dir, err))
				}
			}
		}()
	}
}

func moduleGenerate(modPath string, env environment.Environment) error {
// WorkspaceGenerate lists the packages of every module path in modPaths and submits each package over jobC
// to the workers started by parallelGenerateModulePackages. It returns early with an error when listing a
// module's packages fails or when channels.Submit reports false; otherwise it returns the errors collected
// through addErr, combined with errors.Join.
//
// NOTE(review): this span comes from a rendered diff without +/- markers, so removed and added lines are
// interleaved. The lines that use wg, mu, packages and pkg appear to belong to the removed moduleGenerate
// body - confirm against the merged file.
func WorkspaceGenerate(modPaths []string, env environment.Environment) error {
var (
errs []error
wg sync.WaitGroup
mu sync.Mutex
)
errs []error
errsLock = &sync.Mutex{}
// addErr appends to errs while holding errsLock; it is the callback handed to the workers
addErr = func(err error) {
errsLock.Lock()
defer errsLock.Unlock()

if packages, err := moduleListPackages(modPath); err != nil {
return fmt.Errorf("listing packages for module %s: %w", modPath, err)
} else {
for _, pkg := range packages {
wg.Add(1)
go func(pkg GoPackage) {
defer wg.Done()

var (
command = "go"
args = []string{"generate", pkg.Dir}
)

if err := cmdrunner.Run(command, args, pkg.Dir, env); err != nil {
mu.Lock()
errs = append(errs, fmt.Errorf("generate code for package %s: %w", pkg, err))
mu.Unlock()
}
}(pkg)
errs = append(errs, err)
}

wg.Wait()
// Unbuffered channel over which packages are handed to the workers
jobC = make(chan GoPackage)
waitGroup = &sync.WaitGroup{}
)

// Start the parallel workers first
// NOTE(review): the starter itself runs in its own goroutine, so the waitGroup.Add calls it performs are
// not ordered before any later statement of this function.
go parallelGenerateModulePackages(jobC, waitGroup, env, addErr)

return errors.Join(errs...)
// For each known module path attempt generation of each module package
for _, modPath := range modPaths {
if modulePackages, err := moduleListPackages(modPath); err != nil {
return fmt.Errorf("listing packages for module %s: %w", modPath, err)
} else {
for _, modulePackage := range modulePackages {
if !channels.Submit(context.Background(), jobC, modulePackage) {
return fmt.Errorf("canceled")
}
}
}
}

// NOTE(review): nothing in the lines visible here closes jobC, calls waitGroup.Wait(), or takes errsLock
// before errs is read, so workers may still be running (and reporting errors) when this returns - confirm.
return errors.Join(errs...)
}
4 changes: 2 additions & 2 deletions packages/go/stbernard/workspace/golang/mod.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func moduleListPackages(modPath string) ([]GoPackage, error) {
}
packages = append(packages, p)
}
cmd.Wait()
return packages, nil

return packages, cmd.Wait()
}
}
Loading
0