ref: b8d3eb98ea92d0420feef324eb54a099465707e3
dir: /sys/src/cmd/ssh.c/
#include <u.h>
#include <libc.h>
#include <mp.h>
#include <libsec.h>
#include <auth.h>
#include <authsrv.h>

/* SSH protocol message numbers. */
enum {
	MSG_DISCONNECT = 1,
	MSG_IGNORE,
	MSG_UNIMPLEMENTED,
	MSG_DEBUG,
	MSG_SERVICE_REQUEST,
	MSG_SERVICE_ACCEPT,

	MSG_KEXINIT = 20,
	MSG_NEWKEYS,

	MSG_ECDH_INIT = 30,
	MSG_ECDH_REPLY,

	MSG_USERAUTH_REQUEST = 50,
	MSG_USERAUTH_FAILURE,
	MSG_USERAUTH_SUCCESS,
	MSG_USERAUTH_BANNER,

	/* both are 60: the meaning depends on the auth method in progress */
	MSG_USERAUTH_PK_OK = 60,
	MSG_USERAUTH_INFO_REQUEST = 60,
	MSG_USERAUTH_INFO_RESPONSE = 61,

	MSG_GLOBAL_REQUEST = 80,
	MSG_REQUEST_SUCCESS,
	MSG_REQUEST_FAILURE,

	MSG_CHANNEL_OPEN = 90,
	MSG_CHANNEL_OPEN_CONFIRMATION,
	MSG_CHANNEL_OPEN_FAILURE,
	MSG_CHANNEL_WINDOW_ADJUST,
	MSG_CHANNEL_DATA,
	MSG_CHANNEL_EXTENDED_DATA,
	MSG_CHANNEL_EOF,
	MSG_CHANNEL_CLOSE,
	MSG_CHANNEL_REQUEST,
	MSG_CHANNEL_SUCCESS,
	MSG_CHANNEL_FAILURE,
};

enum {
	Overhead = 256,		/* enough for MSG_CHANNEL_DATA header */
	MaxPacket = 1<<15,
	WinPackets = 8,		/* (1<<15) * 8 = 256K */
};

int MaxPwTries = 3;		/* retry this often for keyboard-interactive */

/*
 * State for one direction of the connection (recv or send):
 * packet sequence number, rekey threshold, channel id, flow-control
 * window and max packet size, cipher states, the packet buffer b with
 * r..w delimiting the current payload, the version string and the pid
 * of the process driving this direction.
 */
typedef struct
{
	u32int	seq;
	u32int	kex;
	u32int	chan;

	int	win;
	int	pkt;
	int	eof;

	Chachastate	cs1;
	Chachastate	cs2;

	uchar	*r;
	uchar	*w;
	uchar	b[Overhead + MaxPacket];

	char	*v;
	int	pid;
	Rendez;
} Oneway;

int nsid;
uchar sid[256];
char thumb[2*SHA2_256dlen+1], *thumbfile;

int fd, intr, raw, port, mux, debug;
char *user, *service, *status, *host, *remote, *cmd;

Oneway recv, send;
void dispatch(void);

/* Mark both directions finished and ask the sending process (if any) to exit. */
void
shutdown(void)
{
	recv.eof = send.eof = 1;
	if(send.pid > 0)
		postnote(PNPROC, send.pid, "shutdown");
}

/* Note handler: record interrupts in intr and continue; default action otherwise. */
void
catch(void*, char *msg)
{
	if(strcmp(msg, "interrupt") == 0){
		intr = 1;
		noted(NCONT);
	}
	noted(NDFLT);
}

/* Report whether the pending error string is "interrupted", leaving it intact. */
int
wasintr(void)
{
	char err[ERRMAX];
	int r;

	memset(err, 0, sizeof(err));
	errstr(err, sizeof(err));
	r = strcmp(err, "interrupted") == 0;
	errstr(err, sizeof(err));	/* swap the original error string back */
	return r;
}

/* Big-endian 32-bit store/load; parenthesized so they are safe inside larger expressions. */
#define PUT4(p, u)	((p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u))
#define GET4(p)	((u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24)

/*
 * Serialize arguments into p[0..n) according to fmt:
 *	b	byte (int)
 *	u	32-bit big-endian (int)
 *	s	4-byte length followed by data (void*, int)
 *	[	raw data without length (void*, int)
 *	m	mpint as length-prefixed big-endian bytes (mpint*)
 *	_	skip one byte
 *	.	store the current output position (void**)
 * Returns the number of bytes written, or -1 if the output does not fit
 * or fmt contains an unknown verb.
 */
int
vpack(uchar *p, int n, char *fmt, va_list a)
{
	uchar *p0 = p, *e = p+n;
	u32int u;
	mpint *m;
	void *s;
	int c;

	for(;;){
		switch(c = *fmt++){
		case '\0':
			return p - p0;
		case '_':
			if(++p > e) goto err;
			break;
		case '.':
			*va_arg(a, void**) = p;
			break;
		case 'b':
			if(p >= e) goto err;
			*p++ = va_arg(a, int);
			break;
		case 'm':
			m = va_arg(a, mpint*);
			u = (mpsignif(m)+8)/8;
			if(p+4 > e) goto err;
			PUT4(p, u), p += 4;
			if(u > e-p) goto err;
			mptober(m, p, u), p += u;
			break;
		case '[':
		case 's':
			s = va_arg(a, void*);
			u = va_arg(a, int);
			if(c == 's'){
				if(p+4 > e) goto err;
				PUT4(p, u), p += 4;
			}
			if(u > e-p) goto err;
			memmove(p, s, u);
			p += u;
			break;
		case 'u':
			u = va_arg(a, int);
			if(p+4 > e) goto err;
			PUT4(p, u), p += 4;
			break;
		default:
			/* unknown verb: a caller bug, never silently skip it */
			goto err;
		}
	}
err:
	return -1;
}

/*
 * Parse p[0..n) according to fmt (same verbs as vpack):
 *	b	byte into int*
 *	u	32-bit big-endian into int*
 *	s	length-prefixed string: pointer into p (void**) and length (int*)
 *	[	copy a fixed number of bytes out (void*, int)
 *	m	length-prefixed big-endian bytes into an mpint*
 *	_	skip one byte
 *	.	store the current input position (void**)
 * Returns the number of bytes consumed, or -1 on truncated input or an unknown verb.
 */
int
vunpack(uchar *p, int n, char *fmt, va_list a)
{
	uchar *p0 = p, *e = p+n;
	u32int u;
	mpint *m;
	void *s;

	for(;;){
		switch(*fmt++){
		case '\0':
			return p - p0;
		case '_':
			if(++p > e) goto err;
			break;
		case '.':
			*va_arg(a, void**) = p;
			break;
		case 'b':
			if(p >= e) goto err;
			*va_arg(a, int*) = *p++;
			break;
		case 'm':
			if(p+4 > e) goto err;
			u = GET4(p), p += 4;
			if(u > e-p) goto err;
			m = va_arg(a, mpint*);
			betomp(p, u, m), p += u;
			break;
		case 's':
			if(p+4 > e) goto err;
			u = GET4(p), p += 4;
			if(u > e-p) goto err;
			*va_arg(a, void**) = p;
			*va_arg(a, int*) = u;
			p += u;
			break;
		case '[':
			s = va_arg(a, void*);
			u = va_arg(a, int);
			if(u > e-p) goto err;
			memmove(s, p, u);
			p += u;
			break;
		case 'u':
			if(p+4 > e) goto err;
			u = GET4(p);
			*va_arg(a, int*) = u;
			p += 4;
			break;
		default:
			goto err;
		}
	}
err:
	return -1;
}

/* Variadic front end to vpack. */
int
pack(uchar *p, int n, char *fmt, ...)
{
	va_list a;

	va_start(a, fmt);
	n = vpack(p, n, fmt, a);
	va_end(a);
	return n;
}

/* Variadic front end to vunpack. */
int
unpack(uchar *p, int n, char *fmt, ...)
{
	va_list a;

	va_start(a, fmt);
	n = vunpack(p, n, fmt, a);
	va_end(a);
	return n;
}

/*
 * Per-packet cipher setup: the IV is the 64-bit packet sequence number,
 * both chacha states restart at block 0, and the 32-byte poly1305 key
 * otk is the first keystream bytes of cs2.
 */
void
setupcs(Oneway *c, uchar otk[32])
{
	uchar iv[8];

	memset(otk, 0, 32);
	pack(iv, sizeof(iv), "uu", 0, c->seq);
	chacha_setiv(&c->cs1, iv);
	chacha_setiv(&c->cs2, iv);
	chacha_setblock(&c->cs1, 0);
	chacha_setblock(&c->cs2, 0);
	chacha_encrypt(otk, 32, &c->cs2);
}

/*
 * Build a payload in send.b from fmt (see vpack), add length and random
 * padding, encrypt and MAC it once a session id exists, and write it to fd.
 * Any overflow is fatal.
 */
void
sendpkt(char *fmt, ...)
{
	static uchar buf[sizeof(send.b)];
	int n, pad;
	va_list a;

	va_start(a, fmt);
	n = vpack(send.b, sizeof(send.b), fmt, a);
	va_end(a);
	if(n < 0) {
toobig:		sysfatal("sendpkt: message too big");
		return;
	}
	send.r = send.b;
	send.w = send.b+n;

	if(debug > 1)
		fprint(2, "sendpkt: (%d) %.*H\n", send.r[0], (int)(send.w-send.r), send.r);

	if(nsid){
		/* undocumented */
		pad = ChachaBsize - ((5+n) % ChachaBsize) + 4;
	} else {
		for(pad=4; (5+n+pad) % 8; pad++)
			;
	}
	/* the padding is generated in place after the payload: it must stay inside send.b */
	if(n + pad > sizeof(send.b))
		goto toobig;
	prng(send.w, pad);
	n = pack(buf, sizeof(buf)-16, "ub[[", 1+n+pad, pad, send.b, n, send.w, pad);
	if(n < 0)
		goto toobig;
	if(nsid){
		uchar otk[32];

		setupcs(&send, otk);
		chacha_encrypt(buf, 4, &send.cs1);
		chacha_encrypt(buf+4, n-4, &send.cs2);
		poly1305(buf, n, otk, sizeof(otk), buf+n, nil);
		n += 16;	/* append the poly1305 tag */
	}

	if(write(fd, buf, n) != n)
		sysfatal("write: %r");

	send.seq++;
}

/*
 * Read exactly len bytes unless EOF or a non-interrupt error occurs.
 * Interrupted reads are retried. Returns the number of bytes read.
 */
int
readall(int fd, uchar *data, int len)
{
	int n, tot;

	for(tot = 0; tot < len; tot += n){
		n = read(fd, data+tot, len-tot);
		if(n <= 0){
			if(n < 0 && wasintr()){
				n = 0;
				continue;
			} else if(n == 0)
				werrstr("eof");
			break;
		}
	}
	return tot;
}

/*
 * Read one packet into recv.b, verifying the poly1305 tag and decrypting
 * once a session id exists. On return recv.r..recv.w is the payload.
 * Returns the message type; bad lengths and tags are fatal.
 */
int
recvpkt(void)
{
	uchar otk[32], tag[16];
	DigestState *ds = nil;
	int n;

	if(readall(fd, recv.b, 4) != 4)
		sysfatal("read1: %r");
	if(nsid){
		setupcs(&recv, otk);
		ds = poly1305(recv.b, 4, otk, sizeof(otk), nil, nil);
		chacha_encrypt(recv.b, 4, &recv.cs1);
		unpack(recv.b, 4, "u", &n);
		n += 16;	/* the tag follows the packet */
	} else {
		unpack(recv.b, 4, "u", &n);
	}
	if(n < 8 || n > sizeof(recv.b)){
badlen:
		sysfatal("bad length %d", n);
	}
	if(readall(fd, recv.b, n) != n)
		sysfatal("read2: %r");
	if(nsid){
		n -= 16;
		if(n < 0)
			goto badlen;
		poly1305(recv.b, n, otk, sizeof(otk), tag, ds);
		if(tsmemcmp(tag, recv.b+n, 16) != 0)
			sysfatal("bad tag");
		chacha_encrypt(recv.b, n, &recv.cs2);
	}
	n -= recv.b[0]+1;	/* drop the padding-length byte and the padding */
	if(n < 1)
		goto badlen;
	recv.r = recv.b + 1;
	recv.w = recv.r + n;
	recv.seq++;

	if(debug > 1)
		fprint(2, "recvpkt: (%d) %.*H\n", recv.r[0], (int)(recv.w-recv.r), recv.r);

	return recv.r[0];
}

static char sshrsa[] = "ssh-rsa";

/* Encode an RSA public key as an ssh-rsa key blob. Returns its length or -1. */
int
rsapub2ssh(RSApub *rsa, uchar *data, int len)
{
	return pack(data, len, "smm", sshrsa, sizeof(sshrsa)-1, rsa->ek, rsa->n);
}

/* Decode an ssh-rsa key blob. Returns a new RSApub or nil. */
RSApub*
ssh2rsapub(uchar *data, int len)
{
	RSApub *pub;
	char *s;
	int n;

	pub = rsapuballoc();
	pub->n = mpnew(0);
	pub->ek = mpnew(0);
	if(unpack(data, len, "smm", &s, &n, pub->ek, pub->n) < 0
	|| n != sizeof(sshrsa)-1 || memcmp(s, sshrsa, n) != 0){
		rsapubfree(pub);
		return nil;
	}
	return pub;
}

/*
 * Encode signature S as an ssh-rsa signature blob. The signature bytes are
 * staged at their final offset (4+7+4) and packed in place.
 * Returns the blob length or -1.
 */
int
rsasig2ssh(RSApub *pub, mpint *S, uchar *data, int len)
{
	int l = (mpsignif(pub->n)+7)/8;

	if(4+7+4+l > len)
		return -1;
	mptober(S, data+4+7+4, l);
	return pack(data, len, "ss", sshrsa, sizeof(sshrsa)-1, data+4+7+4, l);
}

/* Decode an ssh-rsa signature blob. Returns a new mpint or nil. */
mpint*
ssh2rsasig(uchar *data, int len)
{
	mpint *m;
	char *s;
	int n;

	m = mpnew(0);
	if(unpack(data, len, "sm", &s, &n, m) < 0
	|| n != sizeof(sshrsa)-1 || memcmp(s, sshrsa, n) != 0){
		mpfree(m);
		return nil;
	}
	return m;
}

/* SHA-1 digest of data, ASN.1-encoded and PKCS#1-padded to the size of pub->n. */
mpint*
pkcs1digest(uchar *data, int len, RSApub *pub)
{
	uchar digest[SHA1dlen], buf[256];

	sha1(data, len, digest, nil);
	return pkcs1padbuf(buf, asn1encodedigest(sha1, digest, buf, sizeof(buf)), pub->n, 1);
}

/* Check RSA signature S over data with pub. Overwrites S. Returns nonzero if valid. */
int
pkcs1verify(uchar *data, int len, RSApub *pub, mpint *S)
{
	mpint *V;
	int ret;

	V = pkcs1digest(data, len, pub);
	ret = V != nil;
	if(ret){
		rsaencrypt(pub, S, S);
		ret = mpcmp(V, S) == 0;
		mpfree(V);
	}
	return ret;
}

/* Feed a 4-byte big-endian length and then data into the running SHA-256 state ds. */
DigestState*
hashstr(void *data, ulong len, DigestState *ds)
{
	uchar l[4];

	pack(l, 4, "u", len);
	return sha2_256((uchar*)data, len, nil, sha2_256(l, 4, nil, ds));
}

/*
 * Key derivation with SHA-256: the first block is
 * HASH(K || H || x || session_id), and each further block is
 * HASH(K || H || all output so far), until len bytes are produced.
 */
void
kdf(uchar *k, int nk, uchar *h, char x, uchar *out, int len)
{
	uchar digest[SHA2_256dlen], *out0;
	DigestState *ds;
	int n;

	ds = hashstr(k, nk, nil);
	ds = sha2_256(h, sizeof(digest), nil, ds);
	ds = sha2_256((uchar*)&x, 1, nil, ds);
	sha2_256(sid, nsid, digest, ds);
	for(out0=out;;){
		n = len;
		if(n > sizeof(digest))
			n = sizeof(digest);
		memmove(out, digest, n);
		len -= n;
		if(len == 0)
			break;
		out += n;
		ds = hashstr(k, nk, nil);
		ds = sha2_256(h, sizeof(digest), nil, ds);
		sha2_256(out0, out-out0, digest, ds);
	}
}

/*
 * Run a curve25519-sha256 key exchange with an ssh-rsa host key.
 * On the first exchange, check the host key thumbprint against
 * thumbfile. Then switch both directions to chacha20-poly1305 with the
 * derived keys. gotkexinit is nonzero when the peer's KEXINIT is
 * already in recv.
 */
void
kex(int gotkexinit)
{
	static char kexalgs[] = "curve25519-sha256,curve25519-sha256@libssh.org";
	static char cipheralgs[] = "chacha20-poly1305@openssh.com";
	static char zipalgs[] = "none";
	static char macalgs[] = "hmac-sha1";	/* work around for github.com */
	static char langs[] = "";

	uchar cookie[16], x[32], yc[32], z[32], k[32+1], h[SHA2_256dlen], *ys, *ks, *sig;
	uchar k12[2*ChachaKeylen];
	int i, nk, nys, nks, nsig;
	DigestState *ds;
	mpint *S, *K;
	RSApub *pub;

	/* exchange hash input: V_C, V_S, I_C, I_S, K_S, Q_C, Q_S, K */
	ds = hashstr(send.v, strlen(send.v), nil);
	ds = hashstr(recv.v, strlen(recv.v), ds);

	genrandom(cookie, sizeof(cookie));
	sendpkt("b[ssssssssssbu", MSG_KEXINIT,
		cookie, sizeof(cookie),
		kexalgs, sizeof(kexalgs)-1,
		sshrsa, sizeof(sshrsa)-1,
		cipheralgs, sizeof(cipheralgs)-1,
		cipheralgs, sizeof(cipheralgs)-1,
		macalgs, sizeof(macalgs)-1,
		macalgs, sizeof(macalgs)-1,
		zipalgs, sizeof(zipalgs)-1,
		zipalgs, sizeof(zipalgs)-1,
		langs, sizeof(langs)-1,
		langs, sizeof(langs)-1,
		0,
		0);
	ds = hashstr(send.r, send.w-send.r, ds);

	if(!gotkexinit){
	Next0:	switch(recvpkt()){
		default:
			dispatch();
			goto Next0;
		case MSG_KEXINIT:
			break;
		}
	}
	ds = hashstr(recv.r, recv.w-recv.r, ds);

	if(debug){
		char *tab[] = {
			"kexalgs", "hostalgs",
			"cipher1", "cipher2",
			"mac1", "mac2",
			"zip1", "zip2",
			"lang1", "lang2",
			nil,
		}, **t, *s;
		uchar *p = recv.r+17;	/* skip message type and cookie */
		int n;

		for(t=tab; *t != nil; t++){
			if(unpack(p, recv.w-p, "s.", &s, &n, &p) < 0)
				break;
			fprint(2, "%s: %.*s\n", *t, utfnlen(s, n), s);
		}
	}

	curve25519_dh_new(x, yc);
	yc[31] &= ~0x80;

	sendpkt("bs", MSG_ECDH_INIT, yc, sizeof(yc));
Next1:	switch(recvpkt()){
	default:
		dispatch();
		goto Next1;
	case MSG_KEXINIT:
		sysfatal("inception");
	case MSG_ECDH_REPLY:
		if(unpack(recv.r, recv.w-recv.r, "_sss", &ks, &nks, &ys, &nys, &sig, &nsig) < 0)
			sysfatal("bad ECDH_REPLY");
		break;
	}

	if(nys != 32)
		sysfatal("bad server ECDH ephemeral public key length");

	ds = hashstr(ks, nks, ds);
	ds = hashstr(yc, 32, ds);
	ds = hashstr(ys, 32, ds);

	if(thumb[0] == 0){
		Thumbprint *ok;

		sha2_256(ks, nks, h, nil);
		i = enc64(thumb, sizeof(thumb), h, sizeof(h));
		while(i > 0 && thumb[i-1] == '=')
			i--;
		thumb[i] = '\0';

		if(debug)
			fprint(2, "host fingerprint: %s\n", thumb);

		ok = initThumbprints(thumbfile, nil, "ssh");
		if(ok == nil || !okThumbprint(h, sizeof(h), ok)){
			if(ok != nil) werrstr("unknown host");
			fprint(2, "%s: %r\n", argv0);
			fprint(2, "verify hostkey: %s %.*[\n", sshrsa, nks, ks);
			fprint(2, "add thumbprint after verification:\n");
			fprint(2, "\techo 'ssh sha256=%s server=%s' >> %q\n", thumb, host, thumbfile);
			sysfatal("checking hostkey failed: %r");
		}
		freeThumbprints(ok);
	}

	if((pub = ssh2rsapub(ks, nks)) == nil)
		sysfatal("bad server public key");
	if((S = ssh2rsasig(sig, nsig)) == nil)
		sysfatal("bad server signature");

	if(!curve25519_dh_finish(x, ys, z))
		sysfatal("unlucky shared key");

	K = betomp(z, 32, nil);
	nk = (mpsignif(K)+8)/8;
	mptober(K, k, nk);
	mpfree(K);

	ds = hashstr(k, nk, ds);
	sha2_256(nil, 0, h, ds);
	if(!pkcs1verify(h, sizeof(h), pub, S))
		sysfatal("server verification failed");
	mpfree(S);
	rsapubfree(pub);

	sendpkt("b", MSG_NEWKEYS);
Next2:	switch(recvpkt()){
	default:
		dispatch();
		goto Next2;
	case MSG_KEXINIT:
		sysfatal("inception");
	case MSG_NEWKEYS:
		break;
	}

	/* next key exchange */
	recv.kex = recv.seq + 100000;
	send.kex = send.seq + 100000;

	/* the first exchange hash becomes the session id */
	if(nsid == 0)
		memmove(sid, h, nsid = sizeof(h));

	kdf(k, nk, h, 'C', k12, sizeof(k12));
	setupChachastate(&send.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
	setupChachastate(&send.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);

	kdf(k, nk, h, 'D', k12, sizeof(k12));
	setupChachastate(&recv.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
	setupChachastate(&recv.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
}

/* Comma-separated method list from the last USERAUTH_FAILURE, or nil before any failure. */
static char *authnext;

/* Report whether meth may be tried: nothing is known yet, or the server listed it. */
int
authok(char *meth)
{
	int ok = authnext == nil || strstr(authnext, meth) != nil;

	if(debug)
		fprint(2, "userauth %s %s\n", meth, ok ? "ok" : "skipped");
	return ok;
}

/*
 * Parse a USERAUTH_FAILURE in recv and remember the methods the server
 * still accepts. Returns nonzero if meth should not be retried
 * (partial success, or meth is no longer listed).
 */
int
authfailure(char *meth)
{
	char *s;
	int n, partial;

	if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &partial) < 0)
		sysfatal("bad auth failure response");
	free(authnext);
	authnext = smprint("%.*s", utfnlen(s, n), s);
	if(debug)
		fprint(2, "userauth %s failed: partial=%d, next=%s\n", meth, partial, authnext);
	return partial != 0 || !authok(meth);
}

/* Try the "none" method. Returns 0 on success, -1 otherwise. */
int
noneauth(void)
{
	static char authmeth[] = "none";

	if(!authok(authmeth))
		return -1;

	sendpkt("bsss", MSG_USERAUTH_REQUEST,
		user, strlen(user),
		service, strlen(service),
		authmeth, sizeof(authmeth)-1);

Next0:	switch(recvpkt()){
	default:
		dispatch();
		goto Next0;
	case MSG_USERAUTH_FAILURE:
		werrstr("authentication needed");
		authfailure(authmeth);
		return -1;
	case MSG_USERAUTH_SUCCESS:
		return 0;
	}
}

/*
 * Try "publickey" with each RSA key factotum offers. For each key, query
 * it with the server, then have factotum sign the session id plus the
 * request. Returns 0 on success, -1 otherwise.
 */
int
pubkeyauth(void)
{
	static char authmeth[] = "publickey";

	uchar pk[4096], sig[4096];
	int npk, nsig;

	int afd, n;
	char *s;
	mpint *S;
	AuthRpc *rpc;
	RSApub *pub;

	if(!authok(authmeth))
		return -1;

	if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
		return -1;
	if((rpc = auth_allocrpc(afd)) == nil){
		close(afd);
		return -1;
	}

	s = "proto=rsa service=ssh role=client";
	if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
		auth_freerpc(rpc);
		close(afd);
		return -1;
	}

	pub = rsapuballoc();
	pub->n = mpnew(0);
	pub->ek = mpnew(0);

	/* each read returns one candidate key as "n ek" in hex */
	while(auth_rpc(rpc, "read", nil, 0) == ARok){
		s = rpc->arg;
		if(strtomp(s, &s, 16, pub->n) == nil)
			break;
		if(*s++ != ' ')
			continue;
		if(strtomp(s, nil, 16, pub->ek) == nil)
			continue;
		npk = rsapub2ssh(pub, pk, sizeof(pk));

		sendpkt("bsssbss", MSG_USERAUTH_REQUEST,
			user, strlen(user),
			service, strlen(service),
			authmeth, sizeof(authmeth)-1,
			0,
			sshrsa, sizeof(sshrsa)-1,
			pk, npk);
Next1:		switch(recvpkt()){
		default:
			dispatch();
			goto Next1;
		case MSG_USERAUTH_FAILURE:
			if(authfailure(authmeth))
				goto Failed;
			continue;
		case MSG_USERAUTH_SUCCESS:
		case MSG_USERAUTH_PK_OK:
			break;
		}

		/* sign sid and the userauth request */
		n = pack(send.b, sizeof(send.b), "sbsssbss",
			sid, nsid,
			MSG_USERAUTH_REQUEST,
			user, strlen(user),
			service, strlen(service),
			authmeth, sizeof(authmeth)-1,
			1,
			sshrsa, sizeof(sshrsa)-1,
			pk, npk);
		S = pkcs1digest(send.b, n, pub);
		if(S == nil)
			break;
		n = snprint((char*)send.b, sizeof(send.b), "%B", S);
		mpfree(S);

		if(auth_rpc(rpc, "write", (char*)send.b, n) != ARok)
			break;
		if(auth_rpc(rpc, "read", nil, 0) != ARok)
			break;

		/* factotum's reply may not parse; never hand a nil mpint to rsasig2ssh */
		S = strtomp(rpc->arg, nil, 16, nil);
		if(S == nil)
			break;
		nsig = rsasig2ssh(pub, S, sig, sizeof(sig));
		mpfree(S);
		if(nsig < 0)
			break;

		/* send final userauth request with the signature */
		sendpkt("bsssbsss", MSG_USERAUTH_REQUEST,
			user, strlen(user),
			service, strlen(service),
			authmeth, sizeof(authmeth)-1,
			1,
			sshrsa, sizeof(sshrsa)-1,
			pk, npk,
			sig, nsig);
Next2:		switch(recvpkt()){
		default:
			dispatch();
			goto Next2;
		case MSG_USERAUTH_FAILURE:
			if(authfailure(authmeth))
				goto Failed;
			continue;
		case MSG_USERAUTH_SUCCESS:
			break;
		}
		rsapubfree(pub);
		auth_freerpc(rpc);
		close(afd);
		return 0;
	}
Failed:
	rsapubfree(pub);
	auth_freerpc(rpc);
	close(afd);
	return -1;
}

/* Try "password" with a password obtained from factotum. Returns 0 on success, -1 otherwise. */
int
passauth(void)
{
	static char authmeth[] = "password";
	UserPasswd *up;

	if(!authok(authmeth))
		return -1;

	up = auth_getuserpasswd(auth_getkey, "proto=pass service=ssh user=%q server=%q thumb=%q",
		user, host, thumb);
	if(up == nil)
		return -1;

	sendpkt("bsssbs", MSG_USERAUTH_REQUEST,
		user, strlen(user),
		service, strlen(service),
		authmeth, sizeof(authmeth)-1,
		0,
		up->passwd, strlen(up->passwd));

	memset(up->passwd, 0, strlen(up->passwd));
	free(up);

Next0:	switch(recvpkt()){
	default:
		dispatch();
		goto Next0;
	case MSG_USERAUTH_FAILURE:
		werrstr("wrong password");
		authfailure(authmeth);
		return -1;
	case MSG_USERAUTH_SUCCESS:
		return 0;
	}
}

/*
 * Try "keyboard-interactive": show the server's prompts on /dev/cons
 * and send back the answers, up to MaxPwTries attempts.
 * Returns 0 on success, -1 otherwise.
 */
int
kbintauth(void)
{
	static char authmeth[] = "keyboard-interactive";
	int tries;

	char *name, *inst, *s, *a;
	int cfd, i, n, m;	/* cfd: console; distinct from the global connection fd */
	int nquest, echo;
	uchar *ans, *answ;

	tries = 0;

	if(!authok(authmeth))
		return -1;

Loop:
	if(++tries > MaxPwTries)
		return -1;

	sendpkt("bsssss", MSG_USERAUTH_REQUEST,
		user, strlen(user),
		service, strlen(service),
		authmeth, sizeof(authmeth)-1,
		"", 0,
		"", 0);

Next0:	switch(recvpkt()){
	default:
		dispatch();
		goto Next0;
	case MSG_USERAUTH_FAILURE:
		werrstr("keyboard-interactive failed");
		if(authfailure(authmeth))
			return -1;
		goto Loop;
	case MSG_USERAUTH_SUCCESS:
		return 0;
	case MSG_USERAUTH_INFO_REQUEST:
		break;
	}
Retry:
	if((cfd = open("/dev/cons", OWRITE)) < 0)
		return -1;

	if(unpack(recv.r, recv.w-recv.r, "_ss.", &name, &n, &inst, &m, &recv.r) < 0)
		sysfatal("bad info request: name, inst");

	while(n > 0 && strchr("\r\n\t ", name[n-1]) != nil)
		n--;
	while(m > 0 && strchr("\r\n\t ", inst[m-1]) != nil)
		m--;

	if(n > 0)
		fprint(cfd, "%.*s\n", utfnlen(name, n), name);
	if(m > 0)
		fprint(cfd, "%.*s\n", utfnlen(inst, m), inst);

	/* lang, nprompt */
	if(unpack(recv.r, recv.w-recv.r, "su.", &s, &n, &nquest, &recv.r) < 0)
		sysfatal("bad info request: lang, #quest");

	ans = answ = nil;
	for(i = 0; i < nquest; i++){
		if(unpack(recv.r, recv.w-recv.r, "sb.", &s, &n, &echo, &recv.r) < 0)
			sysfatal("bad info request: question [%d]", i);

		while(n > 0 && strchr("\r\n\t :", s[n-1]) != nil)
			n--;
		s[n] = '\0';	/* echo was already consumed, so terminating in place is safe */

		if((a = readcons(s, nil, !echo)) == nil)
			sysfatal("readcons: %r");

		n = answ - ans;
		m = strlen(a)+4;
		if((s = realloc(ans, n + m)) == nil)
			sysfatal("realloc: %r");
		ans = (uchar*)s;
		answ = ans+n;
		answ += pack(answ, m, "s", a, m-4);

		/* the answer may be a password: scrub and release readcons' copy */
		memset(a, 0, m-4);
		free(a);
	}

	sendpkt("bu[", MSG_USERAUTH_INFO_RESPONSE, i, ans, answ - ans);
	if(ans != nil)
		memset(ans, 0, answ - ans);
	free(ans);
	close(cfd);

Next1:	switch(recvpkt()){
	default:
		dispatch();
		goto Next1;
	case MSG_USERAUTH_INFO_REQUEST:
		goto Retry;
	case MSG_USERAUTH_FAILURE:
		werrstr("keyboard-interactive failed");
		if(authfailure(authmeth))
			return -1;
		goto Loop;
	case MSG_USERAUTH_SUCCESS:
		return 0;
	}
}

/*
 * Handle the packet in recv that the caller did not expect.
 * Transport-level messages are handled first. In mux mode everything
 * else is copied raw to stdout; otherwise channel messages are handled.
 * Anything unrecognized or malformed is fatal.
 */
void
dispatch(void)
{
	char *s;
	uchar *p;
	int n, b, c;

	switch(recv.r[0]){
	case MSG_IGNORE:
		return;
	case MSG_GLOBAL_REQUEST:
		if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &b) < 0)
			break;
		if(debug)
			fprint(2, "%s: global request: %.*s\n", argv0, utfnlen(s, n), s);
		if(b != 0)
			sendpkt("b", MSG_REQUEST_FAILURE);
		return;
	case MSG_DISCONNECT:
		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
			break;
		sysfatal("disconnect: (%d) %.*s", c, utfnlen(s, n), s);
		return;
	case MSG_DEBUG:
		/* byte type, boolean always_display, string message, string lang */
		if(unpack(recv.r, recv.w-recv.r, "_bs", &c, &s, &n) < 0)
			break;
		if(c != 0 || debug)
			fprint(2, "%s: %.*s\n", argv0, utfnlen(s, n), s);
		return;
	case MSG_USERAUTH_BANNER:
		if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
			break;
		if(raw)
			write(2, s, n);
		return;
	case MSG_KEXINIT:
		kex(1);
		return;
	}

	if(mux){
		n = recv.w - recv.r;
		if(write(1, recv.r, n) != n)
			sysfatal("write out: %r");
		return;
	}

	switch(recv.r[0]){
	case MSG_CHANNEL_DATA:
		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		if(write(1, s, n) != n)
			sysfatal("write out: %r");
	Winadjust:
		recv.win -= n;
		if(recv.win < recv.pkt){
			n = WinPackets*recv.pkt;
			recv.win += n;
			sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, send.chan, n);
		}
		return;
	case MSG_CHANNEL_EXTENDED_DATA:
		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		if(b == 1)
			write(2, s, n);
		goto Winadjust;
	case MSG_CHANNEL_WINDOW_ADJUST:
		if(unpack(recv.r, recv.w-recv.r, "_uu", &c, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		send.win += n;
		/* the writer sleeps while send.win < 0: wake it as soon as that no longer holds */
		if(send.win >= 0)
			rwakeup(&send);
		return;
	case MSG_CHANNEL_REQUEST:
		if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
			break;
		if(c != recv.chan)
			break;
		if(n == 11 && memcmp(s, "exit-signal", n) == 0){
			if(unpack(p, recv.w-p, "s", &s, &n) < 0)
				break;
			if(n != 0 && status == nil)
				status = smprint("%.*s", utfnlen(s, n), s);
			c = MSG_CHANNEL_SUCCESS;
		} else if(n == 11 && memcmp(s, "exit-status", n) == 0){
			if(unpack(p, recv.w-p, "u", &n) < 0)
				break;
			if(n != 0 && status == nil)
				status = smprint("%d", n);
			c = MSG_CHANNEL_SUCCESS;
		} else {
			if(debug)
				fprint(2, "%s: channel request: %.*s\n", argv0, utfnlen(s, n), s);
			c = MSG_CHANNEL_FAILURE;
		}
		/* replies are addressed to the peer's channel id, like every other outgoing channel message */
		if(b != 0)
			sendpkt("bu", c, send.chan);
		return;
	case MSG_CHANNEL_EOF:
		recv.eof = 1;
		if(!raw)
			write(1, "", 0);
		return;
	case MSG_CHANNEL_CLOSE:
		shutdown();
		return;
	}
	sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r);
}

/*
 * Read the server identification line from fd byte-by-byte into send.b,
 * strip trailing CR/LF, and return it (the storage is send.b).
 */
char*
readline(void)
{
	uchar *p;

	for(p = send.b; p < &send.b[sizeof(send.b)-1]; p++){
		*p = '\0';
		if(read(fd, p, 1) != 1 || *p == '\n')
			break;
	}
	/* trim backwards without ever forming a pointer before send.b */
	while(*p == '\n' || *p == '\r'){
		*p = '\0';
		if(p == send.b)
			break;
		p--;
	}
	return (char*)send.b;
}

static struct {
	char	*term;
	int	xpixels;
	int	ypixels;
	int	lines;
	int	cols;
	int	gen;
} tty;

/*
 * Refresh tty geometry from the environment. Returns 0 if WINCH shows
 * the same generation as last time (no change), else 1.
 */
int
getdim(void)
{
	char *s;
	int g;

	if(s = getenv("WINCH")){
		g = atoi(s);
		free(s);	/* freed before the early return below */
		if(tty.gen == g)
			return 0;
		tty.gen = g;
	}
	if(s = getenv("XPIXELS")){
		tty.xpixels = atoi(s);
		free(s);
	}
	if(s = getenv("YPIXELS")){
		tty.ypixels = atoi(s);
		free(s);
	}
	if(s = getenv("LINES")){
		tty.lines = atoi(s);
		free(s);
	}
	if(s = getenv("COLS")){
		tty.cols = atoi(s);
		free(s);
	}
	return 1;
}

/*
 * Point fds 0, 1 and 2 at /dev/cons, put the console in raw mode with
 * window-change notes, and load the tty geometry.
 */
void
rawon(void)
{
	int ctl;

	close(0);
	if(open("/dev/cons", OREAD) != 0)
		sysfatal("open: %r");
	close(1);
	if(open("/dev/cons", OWRITE) != 1)
		sysfatal("open: %r");
	dup(1, 2);
	/* ctl is deliberately left open: closing consctl reverts the console modes */
	if((ctl = open("/dev/consctl", OWRITE)) >= 0){
		write(ctl, "rawon", 5);
		write(ctl, "winchon", 7);	/* vt(1): interrupt note on window change */
	}
	getdim();
}

#pragma varargck type "k" char*

/*
 * %k: quote a string for a POSIX shell by wrapping it in single quotes
 * and writing each embedded ' as '\''.
 */
int
kfmt(Fmt *f)
{
	char *s, *p;
	int n;

	s = va_arg(f->args, char*);
	n = fmtstrcpy(f, "'");
	while((p = strchr(s, '\'')) != nil){
		*p = '\0';
		n += fmtstrcpy(f, s);
		*p = '\'';
		n += fmtstrcpy(f, "'\\''");
		s = p+1;
	}
	n += fmtstrcpy(f, s);
	n += fmtstrcpy(f, "'");
	return n;
}

void
usage(void)
{
	fprint(2, "usage: %s [-dR] [-t thumbfile] [-T tries] [-u user] [-h] [user@]host [-W remote!port] [cmd args...]\n", argv0);
	exits("usage");
}

/*
 * Connect, do key exchange and user authentication, and open a session,
 * direct-tcpip or mux channel. Then fork: the parent reads and
 * dispatches packets, the child forwards stdin to the server.
 */
void
main(int argc, char *argv[])
{
	static QLock sl;
	int b, n, c;
	char *s;

	quotefmtinstall();
	fmtinstall('B', mpfmt);
	fmtinstall('H', encodefmt);
	fmtinstall('[', encodefmt);
	fmtinstall('k', kfmt);

	tty.gen = -1;
	tty.term = getenv("TERM");
	if(tty.term == nil)
		tty.term = "";
	raw = *tty.term != 0;

	ARGBEGIN {
	case 'd':
		debug++;
		break;
	case 'W':
		remote = EARGF(usage());
		s = strrchr(remote, '!');
		if(s == nil)
			s = strrchr(remote, ':');
		if(s == nil)
			usage();
		*s++ = 0;
		port = atoi(s);
		raw = 0;
		break;
	case 'R':
		raw = 0;
		break;
	case 'r':
		raw = 2; /* bloody */
		break;
	case 'u':
		user = EARGF(usage());
		break;
	case 'h':
		host = EARGF(usage());
		break;
	case 't':
		thumbfile = EARGF(usage());
		break;
	case 'T':
		MaxPwTries = strtol(EARGF(usage()), &s, 0);
		if(*s != 0)
			usage();
		break;
	case 'X':
		mux = 1;
		raw = 0;
		break;
	default:
		usage();
	} ARGEND;

	if(host == nil){
		if(argc == 0)
			usage();
		host = *argv++;
	}

	if(user == nil){
		s = strchr(host, '@');
		if(s != nil){
			*s++ = '\0';
			user = host;
			host = s;
		}
	}

	/* join the remaining arguments into one shell-quoted command line */
	for(cmd = nil; *argv != nil; argv++){
		if(cmd == nil){
			cmd = strdup(*argv);
			if(raw == 1)
				raw = 0;
		}else{
			s = smprint("%s %k", cmd, *argv);
			free(cmd);
			cmd = s;
		}
	}

	if(remote != nil && cmd != nil)
		usage();

	if((fd = dial(netmkaddr(host, nil, "ssh"), nil, nil, nil)) < 0)
		sysfatal("dial: %r");

	send.v = "SSH-2.0-(9)";
	fprint(fd, "%s\r\n", send.v);
	recv.v = readline();
	if(debug)
		fprint(2, "server version: %s\n", recv.v);
	if(strncmp("SSH-2.0-", recv.v, 8) != 0)
		sysfatal("bad server version: %s", recv.v);
	recv.v = strdup(recv.v);	/* readline's storage is send.b, reused by sendpkt */

	send.l = recv.l = &sl;

	if(user == nil)
		user = getuser();
	if(thumbfile == nil)
		thumbfile = smprint("%s/lib/sshthumbs", getenv("home"));

	kex(0);

	sendpkt("bs", MSG_SERVICE_REQUEST, "ssh-userauth", 12);
Next0:	switch(recvpkt()){
	default:
		dispatch();
		goto Next0;
	case MSG_SERVICE_ACCEPT:
		break;
	}

	service = "ssh-connection";
	if(noneauth() < 0 && pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0)
		sysfatal("auth: %r");

	recv.pkt = send.pkt = MaxPacket;
	recv.win = send.win = WinPackets*recv.pkt;
	recv.chan = send.chan = 0;

	if(mux)
		goto Mux;

	/* open hailing frequencies */
	if(remote != nil){
		NetConnInfo *nci = getnetconninfo(nil, fd);
		if(nci == nil)
			sysfatal("can't get netconninfo: %r");
		sendpkt("bsuuususu", MSG_CHANNEL_OPEN,
			"direct-tcpip", 12,
			recv.chan,
			recv.win,
			recv.pkt,
			remote, strlen(remote),
			port,
			nci->laddr, strlen(nci->laddr),
			atoi(nci->lserv));
		free(nci);
	} else {
		sendpkt("bsuuu", MSG_CHANNEL_OPEN,
			"session", 7,
			recv.chan,
			recv.win,
			recv.pkt);
	}
Next1:	switch(recvpkt()){
	default:
		dispatch();
		goto Next1;
	case MSG_CHANNEL_OPEN_FAILURE:
		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
			n = strlen(s = "???");
		sysfatal("channel open failure: (%d) %.*s", b, utfnlen(s, n), s);
	case MSG_CHANNEL_OPEN_CONFIRMATION:
		break;
	}

	if(unpack(recv.r, recv.w-recv.r, "_uuuu", &recv.chan, &send.chan, &send.win, &send.pkt) < 0)
		sysfatal("bad channel open confirmation");
	if(send.pkt <= 0 || send.pkt > MaxPacket)
		send.pkt = MaxPacket;

	if(remote != nil)
		goto Mux;

	if(raw) {
		rawon();
		sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
			send.chan,
			"pty-req", 7,
			0,
			tty.term, strlen(tty.term),
			tty.cols,
			tty.lines,
			tty.xpixels,
			tty.ypixels,
			"", 0);
	}
	if(cmd == nil){
		sendpkt("busb", MSG_CHANNEL_REQUEST,
			send.chan,
			"shell", 5,
			0);
	} else if(*cmd == '#') {
		sendpkt("busbs", MSG_CHANNEL_REQUEST,
			send.chan,
			"subsystem", 9,
			0,
			cmd+1, strlen(cmd)-1);
	} else {
		sendpkt("busbs", MSG_CHANNEL_REQUEST,
			send.chan,
			"exec", 4,
			0,
			cmd, strlen(cmd));
	}

Mux:
	notify(catch);
	atexit(shutdown);

	recv.pid = getpid();
	n = rfork(RFPROC|RFMEM);
	if(n < 0)
		sysfatal("fork: %r");

	/* parent reads and dispatches packets */
	if(n > 0) {
		send.pid = n;
		while(recv.eof == 0){
			recvpkt();
			qlock(&sl);
			dispatch();
			if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
				kex(0);
			qunlock(&sl);
		}
		exits(status);
	}

	/* child reads input and sends packets */
	qlock(&sl);
	for(;;){
		static uchar buf[MaxPacket];

		qunlock(&sl);
		n = read(0, buf, send.pkt);
		qlock(&sl);
		if(send.eof)
			break;
		if(n < 0 && wasintr())
			intr = 1;
		if(intr){
			if(!raw)
				break;
			if(getdim()){
				sendpkt("busbuuuu", MSG_CHANNEL_REQUEST,
					send.chan,
					"window-change", 13,
					0,
					tty.cols,
					tty.lines,
					tty.xpixels,
					tty.ypixels);
			}else{
				sendpkt("busbs", MSG_CHANNEL_REQUEST,
					send.chan,
					"signal", 6,
					0,
					"INT", 3);
			}
			intr = 0;
			continue;
		}
		if(n <= 0)
			break;
		if(mux){
			sendpkt("[", buf, n);
			continue;
		}
		/* reserve window space, then wait (dropping sl) until the server grants it */
		send.win -= n;
		while(send.win < 0)
			rsleep(&send);
		sendpkt("bus", MSG_CHANNEL_DATA,
			send.chan,
			buf, n);
	}
	if(send.eof++ == 0 && !mux)
		sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, send.chan);
	else if(recv.pid > 0 && mux)
		postnote(PNPROC, recv.pid, "shutdown");
	qunlock(&sl);

	exits(nil);
}