diff --git a/cmd/which.go b/cmd/which.go index b66eafc..647d809 100644 --- a/cmd/which.go +++ b/cmd/which.go @@ -9,9 +9,28 @@ import ( var ErrNotFound = errors.New("which: command not found") -func Which(cmd string) (dir string, err error) { +type WhichOptions struct { + Env map[string]string + Sources []string + Cwd string +} - command := exec.Command("which", cmd) +func Which(cmd string, options WhichOptions) (dir string, err error) { + + var sourceCommand strings.Builder + for _, value := range options.Sources { + sourceCommand.WriteString(fmt.Sprintf("source %s && ", value)) + } + + command := exec.Command(sourceCommand.String(), "which", cmd) + + if options.Cwd != "" { + command.Dir = options.Cwd + } + + for k, v := range options.Env { + command.Env = append(command.Env, fmt.Sprintf("%s=%s", k, v)) + } outputBytes, err := command.Output() if err != nil { diff --git a/linux/run.go b/linux/run.go index e890eb7..f83aa10 100644 --- a/linux/run.go +++ b/linux/run.go @@ -160,11 +160,6 @@ func (cmd *LinuxCommand) Run() error { } }() - // Loop through env to format and add them to the command. - for key, value := range cmd.Options.Env { - command.Env = append(command.Env, fmt.Sprintf("%s=%s", key, value)) - } - isCommandExecutable, err := cmd.isCommandExecutable(cmd.Options.Command) if err != nil { return err diff --git a/linux/utils.go b/linux/utils.go index 7146b29..db2a126 100644 --- a/linux/utils.go +++ b/linux/utils.go @@ -12,7 +12,11 @@ import ( func (cmd *LinuxCommand) isCommandExecutable(command string) (bool, error) { - whichOut, err := cmd2.Which(command) + whichOut, err := cmd2.Which(command, cmd2.WhichOptions{ + Env: cmd.Options.Env, + Sources: cmd.Options.Sources, + Cwd: cmd.Options.Cwd, + }) if err != nil { if errors.Is(err, cmd2.ErrNotFound) { if _, err := os.Stat(command); errors.Is(err, fs.ErrNotExist) {