@@ -0,0 +1 @@ | |||
/ssh-ident |
@@ -0,0 +1,176 @@ | |||
package main | |||
import ( | |||
"fmt" | |||
"io/ioutil" | |||
"os" | |||
"os/exec" | |||
"path" | |||
"path/filepath" | |||
"regexp" | |||
"sort" | |||
"strconv" | |||
"strings" | |||
"syscall" | |||
) | |||
type Agent struct { | |||
Identity Identity | |||
Path string | |||
Env []string | |||
} | |||
func (a *Agent) getEnv() []string { | |||
env := os.Environ() | |||
env = append(env, a.Env...) | |||
return env | |||
} | |||
func (a *Agent) loadEnv() { | |||
d, err := ioutil.ReadFile(a.envFile()) | |||
fatal(err) | |||
properties := string(d) | |||
env := extractEnv(properties) | |||
a.Env = env | |||
} | |||
func (a *Agent) loadKeys() { | |||
path := path.Join(a.Identity.Path, "id_*") | |||
keys, err := filepath.Glob(path) | |||
fatal(err) | |||
sort.Strings(keys) | |||
fmt.Fprintf(os.Stderr, "Load private key:\n") | |||
privateKeys := []string{} | |||
for _, key := range keys { | |||
if strings.HasSuffix(key, ".pub") { | |||
continue | |||
} | |||
fmt.Fprintf(os.Stderr, " %s\n", key) | |||
privateKeys = append(privateKeys, key) | |||
} | |||
cmd := exec.Command("ssh-add", privateKeys...) | |||
cmd.Env = a.getEnv() | |||
_, e, err := capture3(cmd) | |||
if err != nil { | |||
os.Stderr.Write(e) | |||
fatal(err) | |||
} | |||
} | |||
func (a *Agent) envFile() string { | |||
return a.Path + ".env" | |||
} | |||
var propertyRegex = regexp.MustCompile(`([^=]+)=([^;]+); export .*;`) | |||
func extractEnv(str string) []string { | |||
properties := strings.Split(str, "\n") | |||
env := []string{} | |||
for _, property := range properties { | |||
match := propertyRegex.FindStringSubmatch(property) | |||
if match != nil { | |||
name := match[1] | |||
value := match[2] | |||
property = fmt.Sprintf("%s=%s", name, value) | |||
env = append(env, property) | |||
} | |||
} | |||
return env | |||
} | |||
func (a *Agent) start() { | |||
fmt.Fprintf(os.Stderr, "Start new agent for identity %s\n", a.Identity.Name) | |||
sock := a.Path + ".sock" | |||
cmd := exec.Command("ssh-agent", "-a", sock) | |||
o, e, err := capture3(cmd) | |||
if err != nil { | |||
os.Stderr.Write(e) | |||
fatal(err) | |||
} | |||
properties := string(o) | |||
env := extractEnv(properties) | |||
a.Env = env | |||
err = ioutil.WriteFile(a.envFile(), o, 0600) | |||
fatal(err) | |||
a.loadKeys() | |||
} | |||
func (a *Agent) getPid() int { | |||
d, err := ioutil.ReadFile(a.envFile()) | |||
fatal(err) | |||
properties := string(d) | |||
env := strings.Split(properties, "\n") | |||
for _, property := range env { | |||
match := propertyRegex.FindStringSubmatch(property) | |||
if match != nil { | |||
name := match[1] | |||
if name == "SSH_AGENT_PID" { | |||
value := match[2] | |||
pid, err := strconv.Atoi(value) | |||
fatal(err) | |||
return pid | |||
} | |||
} | |||
} | |||
return -1 | |||
} | |||
func (a *Agent) init() { | |||
if _, err := os.Stat(a.envFile()); os.IsNotExist(err) { | |||
a.start() | |||
return | |||
} | |||
pid := a.getPid() | |||
if pid <= 0 { | |||
a.start() | |||
return | |||
} | |||
proc, err := os.FindProcess(pid) | |||
if err == nil { | |||
err = proc.Signal(syscall.Signal(0)) | |||
if err != nil { | |||
a.start() | |||
return | |||
} | |||
} | |||
a.loadEnv() | |||
} | |||
func NewAgent(config Config, identity Identity) Agent { | |||
p := path.Join(config.AgentsDir, identity.Name) | |||
agent := Agent{ | |||
Identity: identity, | |||
Path: p, | |||
} | |||
agent.init() | |||
return agent | |||
} | |||
func (a *Agent) Run(config Config, prog string, args []string) { | |||
identity := a.Identity | |||
fmt.Fprintf(os.Stderr, "\033[1;41m[%s]\033[0m %s %s\n", identity.Name, prog, strings.Join(args, " ")) | |||
exe := path.Join(config.BinDir, prog) | |||
if _, err := os.Stat(exe); os.IsNotExist(err) { | |||
fatal(fmt.Errorf("%s: no such file or directory", exe)) | |||
} | |||
sshConfig := path.Join(identity.Path, "config") | |||
_, err := os.Stat(exe) | |||
if err == nil { | |||
args = append([]string{"-F", sshConfig}, args...) | |||
} | |||
args = append([]string{prog}, args...) | |||
env := a.getEnv() | |||
syscall.Exec(exe, args, env) | |||
} |
@@ -0,0 +1,3 @@ | |||
module imirhil.fr/ssh-ident | |||
require gopkg.in/yaml.v2 v2.2.1 |
@@ -0,0 +1,3 @@ | |||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | |||
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= | |||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= |
@@ -0,0 +1,72 @@ | |||
package main | |||
import ( | |||
"os" | |||
"path" | |||
"regexp" | |||
"strings" | |||
) | |||
type Identity struct { | |||
Name string | |||
Path string | |||
} | |||
func newIdentity(config Config, name string) Identity { | |||
return Identity{ | |||
Name: name, | |||
Path: "", | |||
} | |||
} | |||
var hostRegexps = map[string]*regexp.Regexp{ | |||
"ssh": regexp.MustCompile("(?P<user>.*@)?(?P<host>.*)"), | |||
"scp": regexp.MustCompile("(?P<user>.*@)?(?P<host>.*):(?P<file>.*)"), | |||
} | |||
func extractHost(prog string, args []string) string { | |||
args = removeOptions(prog, args) | |||
switch prog { | |||
case "ssh": | |||
if len(args) == 0 { | |||
return "" | |||
} | |||
re := hostRegexps["ssh"] | |||
arg := args[0] | |||
p := matchParams(re, arg) | |||
return p["host"] | |||
case "scp": | |||
re := hostRegexps["scp"] | |||
for _, arg := range args { | |||
if p := matchParams(re, arg); p != nil { | |||
return p["host"] | |||
} | |||
} | |||
} | |||
return "" | |||
} | |||
func findIdentityName(config Config, prog string, args []string) string { | |||
name := os.Getenv("SSH_IDENTITY") | |||
if name != "" { | |||
return name | |||
} | |||
identities := config.Identities | |||
host := extractHost(prog, args) | |||
for match, name := range identities { | |||
if strings.Contains(host, match) { | |||
return name | |||
} | |||
} | |||
return config.DefaultIdentity | |||
} | |||
func FindIdentity(config Config, prog string, args []string) Identity { | |||
name := findIdentityName(config, prog, args) | |||
path := path.Join(config.IdentitiesDir, name) | |||
return Identity{ | |||
Name: name, | |||
Path: path, | |||
} | |||
} |
@@ -0,0 +1,79 @@ | |||
package main | |||
import ( | |||
"io/ioutil" | |||
"os" | |||
"os/user" | |||
"path" | |||
"strings" | |||
"gopkg.in/yaml.v2" | |||
) | |||
type Config struct { | |||
BinDir string `yaml:"bin_dir"` | |||
AgentsDir string | |||
IdentitiesDir string | |||
DefaultIdentity string `yaml:"default_identity"` | |||
Identities map[string]string | |||
} | |||
var sshOptions = map[string]map[string]string{ | |||
"ssh": { | |||
"short": "1246AaconfiggKkMNnqsTtVvXxYy", | |||
"long": "bcDEeFIiJLlmOopQRSWw", | |||
}, | |||
"scp": { | |||
"short": "12346BCpqrv", | |||
"long": "cFiloPS", | |||
}, | |||
} | |||
func removeOptions(prog string, args []string) []string { | |||
notOptions := []string{} | |||
longOptions := sshOptions[prog]["long"] | |||
long := false | |||
for _, arg := range args { | |||
if long { | |||
long = false | |||
continue | |||
} else if strings.HasPrefix(arg, "-") { | |||
last := arg[len(arg)-1:] | |||
long = strings.Contains(longOptions, last) | |||
} else { | |||
notOptions = append(notOptions, arg) | |||
} | |||
} | |||
return notOptions | |||
} | |||
func main() { | |||
usr, err := user.Current() | |||
fatal(err) | |||
home := usr.HomeDir | |||
sshDir := path.Join(home, ".ssh") | |||
identitiesDir := path.Join(sshDir, "identities") | |||
configFile := path.Join(identitiesDir, "config.yml") | |||
data, err := ioutil.ReadFile(configFile) | |||
fatal(err) | |||
config := Config{} | |||
err = yaml.Unmarshal(data, &config) | |||
fatal(err) | |||
config.IdentitiesDir = identitiesDir | |||
prog := os.Getenv("SSH_BINARY") | |||
if prog == "" { | |||
prog = os.Args[0] | |||
prog = path.Base(prog) | |||
} | |||
args := os.Args[1:] | |||
identity := FindIdentity(config, prog, args) | |||
agentsDir := path.Join(sshDir, "agents") | |||
config.AgentsDir = agentsDir | |||
agent := NewAgent(config, identity) | |||
agent.Run(config, prog, args) | |||
} |
@@ -0,0 +1,45 @@ | |||
package main | |||
import ( | |||
"bytes" | |||
"fmt" | |||
"os" | |||
"os/exec" | |||
"regexp" | |||
) | |||
func fatal(err error) { | |||
if err == nil { | |||
return | |||
} | |||
fmt.Fprintf(os.Stderr, "%s\n", err.Error()) | |||
os.Exit(-1) | |||
} | |||
func matchParams(re *regexp.Regexp, str string) map[string]string { | |||
match := re.FindStringSubmatch(str) | |||
if match == nil { | |||
return nil | |||
} | |||
params := make(map[string]string) | |||
for i, name := range re.SubexpNames() { | |||
params[name] = match[i] | |||
} | |||
return params | |||
} | |||
func capture3(cmd *exec.Cmd) ([]byte, []byte, error) { | |||
var stdout, stderr bytes.Buffer | |||
cmd.Stdout = &stdout | |||
cmd.Stderr = &stderr | |||
err := cmd.Run() | |||
return stdout.Bytes(), stderr.Bytes(), err | |||
} | |||
func dumpEnv(env []string) { | |||
for _, e := range env { | |||
fmt.Println(e) | |||
} | |||
} |