ref: 763101974bfb4a1ce622866a1e6a0032ba1489e8
dir: /sys/src/cmd/mpc.y/
%{

/*
 * mpc: translates a small expression language over big integers
 * into C code made of libmp calls (mpadd, mpmodmul, mpsel, ...).
 *
 * Each source function
 *	name(args) stmnt
 * is parsed into a Node tree and emitted by fcom() as
 *	void name(mpint *arg, ...){ ... }
 * on file descriptor 1.
 */

#include <u.h>
#include <libc.h>
#include <bio.h>
#include <mp.h>

typedef struct Sym Sym;
typedef struct Node Node;

/* Sym.f flag bits */
enum {
	FSET	= 1,	/* name has been assigned a value */
	FUSE	= 2,	/* name has been read */
	FARG	= 4,	/* name is a function argument */
	FLOC	= 8,	/* name is a local declared by flocs() */
};

/* interned identifier or numeric literal text */
struct Sym
{
	Sym*	l;	/* next symbol in the same hash bucket */
	int	f;	/* F* flags */
	char	n[];	/* NUL terminated text */
};

/* syntax tree node */
struct Node
{
	int	c;	/* node code: token value or character */
	Node*	l;	/* left operand; also links the tmp lists */
	Node*	r;	/* right operand */
	Sym*	s;	/* symbol for NAME and NUM nodes */
	mpint*	m;	/* constant value once folded by ccom() */
	int	n;	/* source line the node was created on */
};

#pragma	varargck	type	"N"	Node*

int	ntmp;			/* number of tmpN names created in this function */
Node	*ftmps, *atmps;		/* free and allocated temporaries, linked through l */
Node	*modulo;		/* current modulus inside a mod(...) statement, or nil */

Node*	new(int, Node*, Node*);
Sym*	sym(char*);

Biobuf	bin;
int	goteof;
int	lineno;
int	lastch;			/* last value returned by getch(), -1 if none */
int	clevel;			/* indentation level of the generated code */
char*	filename;

int	getch(void);
void	ungetc(void);

void	yyerror(char*);
int	yyparse(void);
void	diag(Node*, char*, ...);
void	com(Node*);
void	fcom(Node*,Node*,Node*);

#pragma varargck argpos cprint 1
#pragma varargck argpos diag 2

%}

%union
{
	Sym*	sval;
	Node*	node;
}

%type	<node>	name num args expr bool block elif stmnt stmnts

%left	'{' '}' ';'
%right	'=' ','
%right	'?' ':'
%left	EQ NEQ '<' '>'
%left	LSH RSH
%left	'+' '-'
%left	'/' '%'
%left	'*'
%left	'^'
%right	'('

%token	MOD IF ELSE WHILE BREAK
%token	<sval>	NAME NUM

%%

prog:
	prog func
|	func

func:
	name args stmnt
	{
		fcom($1, $2, $3);
	}

args:
	'(' expr ')'
	{
		$$ = $2;
	}
|	'(' ')'
	{
		$$ = nil;
	}

name:
	NAME
	{
		$$ = new(NAME,nil,nil);
		$$->s = $1;
	}

num:
	NUM
	{
		$$ = new(NUM,nil,nil);
		$$->s = $1;
	}

/* an if/else chain becomes nested '?' nodes with ':' holding both arms */
elif:
	ELSE IF '(' bool ')' stmnt
	{
		$$ = new('?', $4, new(':', $6, nil));
	}
|	ELSE IF '(' bool ')' stmnt elif
	{
		$$ = new('?', $4, new(':', $6, $7));
	}
|	ELSE stmnt
	{
		$$ = $2;
	}

sem:
	sem ';'
|	';'

stmnt:
	expr '=' expr sem
	{
		$$ = new('=', $1, $3);
	}
|	MOD args stmnt
	{
		$$ = new('m', $2, $3);
	}
|	IF '(' bool ')' stmnt
	{
		$$ = new('?', $3, new(':', $5, nil));
	}
|	IF '(' bool ')' stmnt elif
	{
		$$ = new('?', $3, new(':', $5, $6));
	}
|	WHILE '(' bool ')' stmnt
	{
		/* while(c) s  ==>  for(;;){ if(c) s else break } */
		$$ = new('@', new('?', $3, new(':', $5, new('b', nil, nil))), nil);
	}
|	BREAK sem
	{
		$$ = new('b', nil, nil);
	}
|	expr sem
	{
		/* a bare name as a statement is a call without arguments */
		if($1->c == NAME)
			$$ = new('e', $1, nil);
		else
			$$ = $1;
	}
|	block

block:
	'{' stmnts '}'
	{
		$$ = $2;
	}

stmnts:
	stmnts stmnt
	{
		$$ = new('\n', $1, $2);
	}
|	stmnt

expr:
	'(' expr ')'
	{
		$$ = $2;
	}
|	name
	{
		$$ = $1;
	}
|	num
	{
		$$ = $1;
	}
|	'-' expr
	{
		/* unary minus is built as 0 - expr */
		$$ = new(NUM, nil, nil);
		$$->s = sym("0");
		$$->s->f = 0;
		$$ = new('-', $$, $2);
	}
|	expr ',' expr
	{
		$$ = new(',', $1, $3);
	}
|	expr '^' expr
	{
		$$ = new('^', $1, $3);
	}
|	expr '*' expr
	{
		$$ = new('*', $1, $3);
	}
|	expr '/' expr
	{
		$$ = new('/', $1, $3);
	}
|	expr '%' expr
	{
		$$ = new('%', $1, $3);
	}
|	expr '+' expr
	{
		$$ = new('+', $1, $3);
	}
|	expr '-' expr
	{
		$$ = new('-', $1, $3);
	}
|	bool '?' expr ':' expr
	{
		$$ = new('?', $1, new(':', $3, $5));
	}
|	name args
	{
		$$ = new('e', $1, $2);
	}
|	expr LSH expr
	{
		$$ = new(LSH, $1, $3);
	}
|	expr RSH expr
	{
		$$ = new(RSH, $1, $3);
	}

bool:
	'(' bool ')'
	{
		$$ = $2;
	}
|	'!' bool
	{
		$$ = new('!', $2, nil);
	}
|	expr EQ expr
	{
		$$ = new(EQ, $1, $3);
	}
|	expr NEQ expr
	{
		/* a != b is represented as !(a == b) */
		$$ = new('!', new(EQ, $1, $3), nil);
	}
|	expr '>' expr
	{
		$$ = new('>', $1, $3);
	}
|	expr '<' expr
	{
		$$ = new('<', $1, $3);
	}

%%

/*
 * Lexer.  Skips blanks and '#' comments up to end of line, returns
 * single character operators as themselves, recognises the two
 * character operators << >> == != and the keywords, and hands every
 * other word to the parser as NUM (leading digit) or NAME with the
 * interned symbol in yylval.sval.  Returns -1 at end of input.
 */
int
yylex(void)
{
	static char buf[200];
	char *p;
	int c;

Loop:
	c = getch();
	switch(c){
	case -1:
		return -1;
	case ' ':
	case '\t':
	case '\n':
		goto Loop;
	case '#':
		while((c = getch()) > 0)
			if(c == '\n')
				break;
		goto Loop;
	}

	switch(c){
	case '?': case ':':
	case '+': case '-':
	case '*': case '^':
	case '/': case '%':
	case '{': case '}':
	case '(': case ')':
	case ',': case ';':
		return c;
	case '<':
		if(getch() == '<') return LSH;
		ungetc();
		return '<';
	case '>':
		if(getch() == '>') return RSH;
		ungetc();
		return '>';
	case '=':
		if(getch() == '=') return EQ;
		ungetc();
		return '=';
	case '!':
		if(getch() == '=') return NEQ;
		ungetc();
		return '!';
	}

	/* push the character back and collect a whole word */
	ungetc();
	p = buf;
	for(;;){
		c = getch();
		if((c >= Runeself)
		|| (c == '_')
		|| (c >= 'a' && c <= 'z')
		|| (c >= 'A' && c <= 'Z')
		|| (c >= '0' && c <= '9')){
			/* leave room for the terminating NUL */
			if(p >= buf + sizeof(buf) - 1)
				yyerror("token too long");
			*p++ = c;
			continue;
		}
		ungetc();
		break;
	}
	*p = '\0';

	/*
	 * Nothing collected means c can start no token.  Report it
	 * instead of returning an empty NAME and leaving the character
	 * in the input to be seen again on the next call.
	 */
	if(p == buf){
		snprint(buf, sizeof(buf), "illegal character 0x%02x", c);
		yyerror(buf);
	}

	if(strcmp(buf, "mod") == 0)
		return MOD;
	if(strcmp(buf, "if") == 0)
		return IF;
	if(strcmp(buf, "else") == 0)
		return ELSE;
	if(strcmp(buf, "while") == 0)
		return WHILE;
	if(strcmp(buf, "break") == 0)
		return BREAK;

	yylval.sval = sym(buf);
	yylval.sval->f = 0;
	return (buf[0] >= '0' && buf[0] <= '9') ? NUM : NAME;
}

/*
 * Read one character from bin.  Returns -1 and sets goteof at end
 * of input; counts newlines in lineno.  The value returned is
 * remembered in lastch so that ungetc() can undo the line count.
 */
int
getch(void)
{
	int c;

	c = Bgetc(&bin);
	if(c == Beof){
		goteof = 1;
		lastch = -1;
		return -1;
	}
	if(c == '\n')
		lineno++;
	lastch = c;
	return c;
}

/*
 * Push back the character last returned by getch().  A pushed back
 * newline will be counted again when it is re-read, so take it off
 * lineno here.
 */
void
ungetc(void)
{
	if(lastch == '\n')
		lineno--;
	lastch = -1;
	Bungetc(&bin);
}

/*
 * Allocate a tree node with code c and children l and r, stamped
 * with the current line number.  Does not return on allocation
 * failure.
 */
Node*
new(int c, Node *l, Node *r)
{
	Node *n;

	n = malloc(sizeof(Node));
	if(n == nil)
		sysfatal("malloc: %r");
	n->c = c;
	n->l = l;
	n->r = r;
	n->s = nil;
	n->m = nil;
	n->n = lineno;
	return n;
}

/*
 * Intern the string n: return the existing Sym with that text, or
 * create one with cleared flags.  Equal strings always yield the
 * same pointer, which the rest of the program relies on when it
 * compares Sym pointers.
 */
Sym*
sym(char *n)
{
	static Sym *tab[128];
	Sym *s;
	ulong h, t;
	int i;

	h = 0;
	for(i=0; n[i] != '\0'; i++){
		t = h & 0xf8000000;
		h <<= 5;
		h ^= t>>27;
		h ^= (ulong)n[i];
	}
	h %= nelem(tab);
	for(s = tab[h]; s != nil; s = s->l)
		if(strcmp(s->n, n) == 0)
			return s;
	/* i is the string length here; +1 copies the NUL too */
	s = malloc(sizeof(Sym)+i+1);
	if(s == nil)
		sysfatal("malloc: %r");
	memmove(s->n, n, i+1);
	s->f = 0;
	s->l = tab[h];
	tab[h] = s;
	return s;
}

/* Report a parse error with file and line on fd 2 and exit. */
void
yyerror(char *s)
{
	fprint(2, "%s:%d: %s\n", filename, lineno, s);
	exits(s);
}

/*
 * Formatted output of generated code to fd 1.  After every newline
 * in the formatted text, clevel tab characters are written before
 * whatever follows.
 */
void
cprint(char *fmt, ...)
{
	static char buf[1024], tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
	char *p, *x;
	va_list a;
	int n;

	va_start(a, fmt);
	vsnprint(buf, sizeof(buf), fmt, a);
	va_end(a);

	p = buf;
	while((x = strchr(p, '\n')) != nil){
		x++;
		write(1, p, x-p);
		/* clamp so deep nesting cannot index before tabs[] */
		n = clevel;
		if(n < 0)
			n = 0;
		if(n > (int)sizeof(tabs)-1)
			n = (int)sizeof(tabs)-1;
		p = &tabs[sizeof(tabs)-1 - n];
		if(*p != '\0')
			write(1, p, strlen(p));
		p = x;
	}
	if(*p != '\0')
		write(1, p, strlen(p));
}

/*
 * Get a temporary: reuse one from the free list, or create a new
 * tmpN name and emit its declaration.  In both cases emit the
 * mpnew(0) initialisation, clear the set/use flags and move the
 * node to the allocated list.
 */
Node*
alloctmp(void)
{
	Node *t;

	t = ftmps;
	if(t != nil)
		ftmps = t->l;
	else {
		char n[16];

		snprint(n, sizeof(n), "tmp%d", ++ntmp);
		t = new(NAME, nil, nil);
		t->s = sym(n);

		cprint("mpint *");
	}
	cprint("%N = mpnew(0);\n", t);
	t->s->f &= ~(FSET|FUSE);
	t->l = atmps;
	atmps = t;
	return t;
}

/* True for numeric literals and the names mpzero, mpone and mptwo. */
int
isconst(Node *n)
{
	if(n->c == NUM)
		return 1;
	if(n->c == NAME){
		return	n->s == sym("mpzero") ||
			n->s == sym("mpone") ||
			n->s == sym("mptwo");
	}
	return 0;
}

/* True if n names a temporary that is currently allocated. */
int
istmp(Node *n)
{
	Node *l;

	if(n->c == NAME){
		for(l = atmps; l != nil; l = l->l){
			if(l->s == n->s)
				return 1;
		}
	}
	return 0;
}

/*
 * Release temporaries.  Lists are walked recursively; a NAME node
 * found on the allocated list gets an mpfree() emitted and moves to
 * the free list.  Anything else is ignored.
 */
void
freetmp(Node *t)
{
	Node **ll, *l;

	if(t == nil)
		return;
	if(t->c == ','){
		freetmp(t->l);
		freetmp(t->r);
		return;
	}
	if(t->c != NAME)
		return;

	ll = &atmps;
	for(l = atmps; l != nil; l = l->l){
		if(l == t){
			cprint("mpfree(%N);\n", t);
			*ll = t->l;
			t->l = ftmps;
			ftmps = t;
			return;
		}
		ll = &l->l;
	}
}

/* True if the tree n contains a NAME node for symbol s. */
int
symref(Node *n, Sym *s)
{
	if(n == nil)
		return 0;
	if(n->c == NAME && n->s == s)
		return 1;
	return symref(n->l, s) || symref(n->r, s);
}

/* Mark the name n, or every name in the list n, as set. */
void
nodeset(Node *n)
{
	if(n == nil)
		return;
	if(n->c == NAME){
		n->s->f |= FSET;
		return;
	}
	if(n->c == ','){
		nodeset(n->l);
		nodeset(n->r);
	}
}

/*
 * False for names and for positive constants not greater than two;
 * true for every other node.  ecom() only offers its destination
 * as scratch space to operands for which this is true.
 */
int
complex(Node *n)
{
	if(n->c == NAME)
		return 0;
	if(n->c == NUM && n->m->sign > 0 && mpcmp(n->m, mptwo) <= 0)
		return 0;
	return 1;
}

void
bcom(Node *n, Node *t);

/*
 * Constant folding.  Converts NUM literals to mpints and evaluates
 * operators whose operands are both constant, turning the node
 * into a NUM.  A non-nil f->m marks a node as already visited.
 */
Node*
ccom(Node *f)
{
	Node *l, *r;

	if(f == nil)
		return nil;

	if(f->m != nil)
		return f;
	/* visited marker; replaced by the real value if f folds */
	f->m = (void*)~0;

	switch(f->c){
	case NUM:
		f->m = strtomp(f->s->n, nil, 0, nil);
		if(f->m == nil)
			diag(f, "bad constant");
		goto out;

	case LSH:
	case RSH:
		break;

	case '+':
	case '-':
	case '*':
	case '/':
	case '%':
	case '^':
		if(modulo == nil || modulo->c == NUM)
			break;

		/* wet floor */
	default:
		return f;
	}

	f->l = l = ccom(f->l);
	f->r = r = ccom(f->r);
	if(l == nil || r == nil || l->c != NUM || r->c != NUM)
		return f;

	f->m = mpnew(0);
	switch(f->c){
	case LSH:
	case RSH:
		if(mpsignif(r->m) > 32)
			diag(f, "bad shift");
		if(f->c == LSH)
			mpleft(l->m, mptoi(r->m), f->m);
		else
			mpright(l->m, mptoi(r->m), f->m);
		goto out;

	case '+':
		mpadd(l->m, r->m, f->m);
		break;
	case '-':
		mpsub(l->m, r->m, f->m);
		break;
	case '*':
		mpmul(l->m, r->m, f->m);
		break;
	case '/':
		if(modulo != nil){
			mpinvert(r->m, modulo->m, f->m);
			mpmul(f->m, l->m, f->m);
		} else {
			mpdiv(l->m, r->m, f->m, nil);
		}
		break;
	case '%':
		mpmod(l->m, r->m, f->m);
		break;
	case '^':
		mpexp(l->m, r->m, modulo != nil ? modulo->m : nil, f->m);
		goto out;
	}
	if(modulo != nil)
		mpmod(f->m, modulo->m, f->m);

out:
	f->l = nil;
	f->r = nil;
	f->s = nil;
	f->c = NUM;
	return f;
}

/*
 * Compile expression f, emitting code that leaves its value in t.
 * With t == nil a plain name is returned as is, a list has each
 * element compiled and is returned as a list, and anything else is
 * computed into a fresh temporary.  Returns the node holding the
 * result.  Temporaries used for the operands are released before
 * returning.
 */
Node*
ecom(Node *f, Node *t)
{
	Node *l, *r, *t2, *x;

	if(f == nil)
		return nil;

	f = ccom(f);
	if(f->c == NUM){
		if(f->m->sign < 0){
			/* emit the magnitude, then flip the sign of the result */
			f->m->sign = 1;
			t = ecom(f, t);
			f->m->sign = -1;
			if(isconst(t))
				t = ecom(t, alloctmp());
			cprint("%N->sign = -1;\n", t);
			return t;
		}
		/* 0, 1 and 2 are referenced by name instead of being built */
		if(mpcmp(f->m, mpzero) == 0){
			f->c = NAME;
			f->s = sym("mpzero");
			f->s->f = FSET;
			return ecom(f, t);
		}
		if(mpcmp(f->m, mpone) == 0){
			f->c = NAME;
			f->s = sym("mpone");
			f->s->f = FSET;
			return ecom(f, t);
		}
		if(mpcmp(f->m, mptwo) == 0){
			f->c = NAME;
			f->s = sym("mptwo");
			f->s->f = FSET;
			return ecom(f, t);
		}
	}

	if(f->c == ','){
		if(t != nil)
			diag(f, "cannot assign list to %N", t);
		f->l = ecom(f->l, nil);
		f->r = ecom(f->r, nil);
		return f;
	}

	l = r = nil;
	if(f->c == NAME){
		if((f->s->f & FSET) == 0)
			diag(f, "name used but not set");
		f->s->f |= FUSE;
		if(t == nil)
			return f;
		if(f->s != t->s)
			cprint("mpassign(%N, %N);\n", f, t);
		goto out;
	}

	if(t == nil)
		t = alloctmp();

	if(f->c == '?'){
		bcom(f, t);
		goto out;
	}

	if(f->c == 'e'){
		/* call: the destination is passed as the last argument */
		r = ecom(f->r, nil);
		if(r == nil)
			cprint("%N(%N);\n", f->l, t);
		else
			cprint("%N(%N, %N);\n", f->l, r, t);
		goto out;
	}

	if(t->c != NAME)
		diag(f, "destination %N not a name", t);

	switch(f->c){
	case NUM:
		if(mpsignif(f->m) <= 32)
			cprint("uitomp(%udUL, %N);\n", mptoui(f->m), t);
		else if(mpsignif(f->m) <= 64)
			cprint("uvtomp(%lludULL, %N);\n", mptouv(f->m), t);
		else
			cprint("strtomp(\"%.16B\", nil, 16, %N);\n", f->m, t);
		goto out;
	case LSH:
	case RSH:
		r = ccom(f->r);
		if(r == nil || r->c != NUM || mpsignif(r->m) > 32)
			diag(f, "bad shift");
		l = f->l->c == NAME ? f->l : ecom(f->l, t);
		if(f->c == LSH)
			cprint("mpleft(%N, %d, %N);\n", l, mptoi(r->m), t);
		else
			cprint("mpright(%N, %d, %N);\n", l, mptoi(r->m), t);
		goto out;
	case '*':
	case '/':
		l = ecom(f->l, nil);
		r = ecom(f->r, nil);
		break;
	default:
		l = ccom(f->l);
		r = ccom(f->r);
		/* t may serve as scratch for at most one operand */
		l = ecom(l, complex(l) && !symref(r, t->s) ? t : nil);
		r = ecom(r, complex(r) && l->s != t->s ? t : nil);
		break;
	}

	if(modulo != nil){
		switch(f->c){
		case '+':
			cprint("mpmodadd(%N, %N, %N, %N);\n", l, r, modulo, t);
			goto out;
		case '-':
			cprint("mpmodsub(%N, %N, %N, %N);\n", l, r, modulo, t);
			goto out;
		case '*':
		Modmul:
			if(l->s == sym("mptwo") || r->s == sym("mptwo")){
				/* x is the operand being doubled; name it in the comment too */
				x = r->s == sym("mptwo") ? l : r;
				cprint("mpmodadd(%N, %N, %N, %N); // 2*%N\n",
					x, x, modulo, t, x);
			}else
				cprint("mpmodmul(%N, %N, %N, %N);\n", l, r, modulo, t);
			goto out;
		case '/':
			if(l->s == sym("mpone")){
				cprint("mpinvert(%N, %N, %N);\n", r, modulo, t);
				goto out;
			}
			t2 = alloctmp();
			cprint("mpinvert(%N, %N, %N);\n", r, modulo, t2);
			cprint("mpmodmul(%N, %N, %N, %N);\n", l, t2, modulo, t);
			freetmp(t2);
			goto out;
		case '^':
			if(r->s == sym("mptwo")){
				/* squaring: emit l*l */
				r = l;
				goto Modmul;
			}
			cprint("mpexp(%N, %N, %N, %N);\n", l, r, modulo, t);
			goto out;
		}
	}

	switch(f->c){
	case '+':
		cprint("mpadd(%N, %N, %N);\n", l, r, t);
		goto out;
	case '-':
		if(l->s == sym("mpzero")){
			r = ecom(r, t);
			cprint("%N->sign = -%N->sign;\n", t, t);
		} else
			cprint("mpsub(%N, %N, %N);\n", l, r, t);
		goto out;
	case '*':
	Mul:
		if(l->s == sym("mptwo") || r->s == sym("mptwo"))
			cprint("mpleft(%N, 1, %N);\n", r->s == sym("mptwo") ? l : r, t);
		else
			cprint("mpmul(%N, %N, %N);\n", l, r, t);
		goto out;
	case '/':
		cprint("mpdiv(%N, %N, %N, %N);\n", l, r, t, nil);
		goto out;
	case '%':
		cprint("mpmod(%N, %N, %N);\n", l, r, t);
		goto out;
	case '^':
		if(r->s == sym("mptwo")){
			/* squaring: emit l*l */
			r = l;
			goto Mul;
		}
		cprint("mpexp(%N, %N, nil, %N);\n", l, r, t);
		goto out;
	default:
		diag(f, "unknown operation");
	}

out:
	if(l != t)
		freetmp(l);
	if(r != t)
		freetmp(r);
	nodeset(t);
	return t;
}

/*
 * Compile the conditional node n ('?' with the test in n->l and a
 * ':' node holding both arms in n->r).  With a destination t the
 * result is a single mpsel() call choosing between the two arm
 * values; with t == nil an if(...) statement is emitted and the
 * arms are compiled as statements.
 */
void
bcom(Node *n, Node *t)
{
	Node *f, *l, *r;
	int neg = 0;

	l = r = nil;
	f = n->l;
Loop:
	switch(f->c){
	case '!':
		neg = !neg;
		f = f->l;
		goto Loop;
	case '>':
	case '<':
	case EQ:
		l = ecom(f->l, nil);
		r = ecom(f->r, nil);
		if(t != nil) {
			Node *b1, *b2;

			b1 = ecom(n->r->l, nil);
			b2 = ecom(n->r->r, nil);
			cprint("mpsel(");

			if(l->s == r->s)
				cprint("0");
			else {
				if(f->c == '>')
					cprint("-");
				cprint("mpcmp(%N, %N)", l, r);
			}
			if(f->c == EQ)
				neg = !neg;
			else
				cprint(" >> (sizeof(int)*8-1)");

			cprint(", %N, %N, %N);\n", neg ? b2 : b1, neg ? b1 : b2, t);
			freetmp(b1);
			freetmp(b2);
		} else {
			cprint("if(");

			if(l->s == r->s)
				cprint("0");
			else
				cprint("mpcmp(%N, %N)", l, r);
			if(f->c == EQ)
				cprint(neg ? " != 0" : " == 0");
			else if(f->c == '>')
				cprint(neg ? " <= 0" : " > 0");
			else
				cprint(neg ? " >= 0" : " < 0");

			cprint(")");
			com(n->r);
		}
		break;
	default:
		diag(n, "saw %N in boolean expression", f);
	}
	freetmp(l);
	freetmp(r);
}

/*
 * Compile statement tree n.  Node codes handled here:
 *	'\n'	statement sequence
 *	'?'	if statement
 *	'b'	break, preceded by mpfree of the allocated temporaries
 *	'@'	for(;;) loop around the body
 *	':'	braced body with optional else part
 *	'm'	mod(...) scope: sets modulo while compiling the body
 *	'e'	call statement
 *	'='	assignment
 */
void
com(Node *n)
{
	Node *l, *r;

Loop:
	if(n != nil)
	switch(n->c){
	case '\n':
		com(n->l);
		n = n->r;
		goto Loop;
	case '?':
		bcom(n, nil);
		break;
	case 'b':
		for(l = atmps; l != nil; l = l->l)
			cprint("mpfree(%N);\n", l);
		cprint("break;\n");
		break;
	case '@':
		cprint("for(;;)");
		/* fall through: the loop body is emitted like a ':' body */
	case ':':
		clevel++;
		cprint("{\n");
		/* the tmp lists are saved here and restored after the body */
		l = ftmps;
		r = atmps;
		if(n->c == '@')
			atmps = nil;
		ftmps = nil;
		com(n->l);
		if(n->r != nil){
			cprint("}else{\n");
			ftmps = nil;
			com(n->r);
		}
		ftmps = l;
		atmps = r;
		clevel--;
		cprint("}\n");
		break;
	case 'm':
		l = modulo;
		modulo = ecom(n->l, nil);
		com(n->r);
		freetmp(modulo);
		modulo = l;
		break;
	case 'e':
		if(n->r == nil)
			cprint("%N();\n", n->l);
		else {
			r = ecom(n->r, nil);
			cprint("%N(%N);\n", n->l, r);
			freetmp(r);
		}
		break;
	case '=':
		ecom(n->r, n->l);
		break;
	}
}

/*
 * Collect the local variables of the statement tree n: every name
 * on the left of an assignment that is neither an argument nor
 * already a local is flagged FLOC and prepended to the ',' list r.
 * Returns the new list head.
 */
Node*
flocs(Node *n, Node *r)
{
	Node *a;

Loop:
	if(n != nil)
	switch(n->c){
	default:
		r = flocs(n->l, r);
		r = flocs(n->r, r);
		n = n->r;
		goto Loop;
	case '=':
		a = n;
		n = n->l;
		if(n == nil)
			diag(a, "lhs is nil");	/* report on the assignment; n is nil here */
		while(n->c == ','){
			/* treat each list cell as an assignment to its left element */
			n->c = '=';
			r = flocs(n, r);
			n->c = ',';

			n = n->r;
			if(n == nil)
				return r;
		}
		if(n->c == NAME && (n->s->f & (FARG|FLOC)) == 0){
			n->s->f = FLOC;
			return new(',', n, r);
		}
		break;
	}
	return r;
}

/*
 * Emit one C function: header with the arguments a as mpint*
 * parameters, declarations for the locals found in the body b,
 * the compiled body, and mpfree() calls for the locals.
 */
void
fcom(Node *f, Node *a, Node *b)
{
	Node *a0, *l0, *l;

	ntmp = 0;
	ftmps = atmps = modulo = nil;
	clevel = 1;
	cprint("void %N(", f);
	a0 = a;
	while(a != nil){
		if(a != a0)
			cprint(", ");
		l = a->c == NAME ? a : a->l;
		l->s->f = FARG|FSET;
		cprint("mpint *%N", l);
		a = a->r;
	}
	cprint("){\n");
	l0 = flocs(b, nil);
	for(a = l0; a != nil; a = a->r)
		cprint("mpint *%N = mpnew(0);\n", a->l);
	com(b);
	for(a = l0; a != nil; a = a->r)
		cprint("mpfree(%N);\n", a->l);
	clevel = 0;
	cprint("}\n");
}

/*
 * Report a compile error about node n on fd 2 and exit.  n may be
 * nil, in which case the current input line is reported.
 */
void
diag(Node *n, char *fmt, ...)
{
	static char buf[1024];
	va_list a;

	va_start(a, fmt);
	vsnprint(buf, sizeof(buf), fmt, a);
	va_end(a);

	fprint(2, "%s:%d: for %N; %s\n", filename, n != nil ? n->n : lineno, n, buf);
	exits("error");
}

/*
 * %N format: print a Node.  Lists print comma separated, folded
 * constants print their value, NAME and unfolded NUM print the
 * symbol text, everything else prints its node code.
 */
int
Nfmt(Fmt *f)
{
	Node *n = va_arg(f->args, Node*);

	if(n == nil)
		return fmtprint(f, "nil");

	if(n->c == ',')
		return fmtprint(f, "%N, %N", n->l, n->r);

	switch(n->c){
	case NUM:
		if(n->m != nil)
			return fmtprint(f, "%B", n->m);
		/* wet floor */
	case NAME:
		return fmtprint(f, "%s", n->s->n);
	case EQ:
		return fmtprint(f, "==");
	case IF:
		return fmtprint(f, "if");
	case ELSE:
		return fmtprint(f, "else");
	case MOD:
		return fmtprint(f, "mod");
	default:
		return fmtprint(f, "%c", (char)n->c);
	}
}

/* Compile everything readable from fd; file is used in diagnostics. */
void
parse(int fd, char *file)
{
	Binit(&bin, fd, OREAD);
	filename = file;
	clevel = 0;
	lineno = 1;
	lastch = -1;
	goteof = 0;
	while(!goteof)
		yyparse();
	Bterm(&bin);
}

void
usage(void)
{
	fprint(2, "%s [file ...]\n", argv0);
	exits("usage");
}

/* Compile the named files in order, or standard input if none are given. */
void
main(int argc, char *argv[])
{
	fmtinstall('N', Nfmt);
	fmtinstall('B', mpfmt);

	ARGBEGIN {
	default:
		usage();
	} ARGEND;

	if(argc == 0){
		parse(0, "<stdin>");
		exits(nil);
	}
	while(*argv != nil){
		int fd;

		if((fd = open(*argv, OREAD)) < 0){
			fprint(2, "%s: %r\n", *argv);
			exits("error");
		}
		parse(fd, *argv);
		close(fd);
		argv++;
	}
	exits(nil);
}