/*
 * git: plan9front
 *
 * ref: ad37339a1c39c90f51d273abe60f552d3900f752
 * dir: /sys/src/cmd/dtracy/type.c/
 */
#include <u.h>
#include <libc.h>
#include <ctype.h>
#include <dtracy.h>
#include <bio.h>
#include "dat.h"
#include "fns.h"

/*
 * icast wraps n in a cast to an integer type.
 *
 * size is the width of the target type in bytes and sign is non-zero
 * for a signed target.  the argument order is (size, sign), matching
 * both type(TYPINT, size, sign) and every call in this file.
 *
 * returns the new OCAST node; n becomes its operand.
 */
Node *
icast(int size, int sign, Node *n)
{
	Type *t;

	t = type(TYPINT, size, sign);
	return node(OCAST, t, n);
}

/*
	the type checker checks types.
	the result is an expression that is correct if evaluated with 64-bit operands all the way.
	to maintain c-like semantics, this means adding casts all over the place, which will get optimised later.
	
	note we use kencc, NOT ansi c, semantics for unsigned.
*/

/*
 * typecheck walks the expression rooted at n, fills in n->typ for each
 * node and inserts the integer casts needed for c-like results.
 *
 * returns the root of the checked expression.  this is n itself unless
 * a cast had to be wrapped around it (32-bit arithmetic and shifts).
 *
 * type errors are reported through error(); the offending node is
 * returned with its typ left as it was.  parents look for typ == nil on
 * their children and give up quietly, so one mistake yields one message.
 * node or symbol kinds this pass does not know are fatal.
 */
Node *
typecheck(Node *n)
{
	int s1, s2, sign;

	switch(/*nodetype*/n->type){
	case OSYM:
		switch(n->sym->type){
		case SYMNONE: error("undeclared '%s'", n->sym->name); break;
		case SYMVAR: n->typ = n->sym->typ; break;
		default: sysfatal("typecheck: unknown symbol type %d", n->sym->type);
		}
		break;
	case ONUM:
		/* constants that fit a signed 32-bit integer are 4 bytes wide, all others 8 */
		if((vlong)n->num >= -0x80000000LL && (vlong)n->num <= 0x7fffffffLL)
			n->typ = type(TYPINT, 4, 1);
		else
			n->typ = type(TYPINT, 8, 1);
		break;
	case OSTR:
		n->typ = type(TYPSTRING);
		break;
	case OBIN:
		n->n1 = typecheck(n->n1);
		n->n2 = typecheck(n->n2);
		/* an operand already failed to check; it has been reported */
		if(n->n1->typ == nil || n->n2->typ == nil)
			break;
		if(n->n1->typ->type != TYPINT){
			error("%τ not allowed in operation", n->n1->typ);
			break;
		}
		if(n->n2->typ->type != TYPINT){
			error("%τ not allowed in operation", n->n2->typ);
			break;
		}
		s1 = n->n1->typ->size;
		s2 = n->n2->typ->size;
		/* the result is signed only if both operands are */
		sign = n->n1->typ->sign && n->n2->typ->sign;
		switch(n->op){
		case OPADD:
		case OPSUB:
		case OPMUL:
		case OPDIV:
		case OPMOD:
		case OPAND:
		case OPOR:
		case OPXOR:
		case OPXNOR:
			/*
			 * the operation itself is typed as 64-bit.
			 * if neither operand is wider than 4 bytes, the operands
			 * are cast to 4 bytes and the result is cast back to 4 bytes.
			 */
			n->typ = type(TYPINT, 8, sign);
			if(s1 > 4 || s2 > 4){
				n->n1 = icast(8, sign, n->n1);
				n->n2 = icast(8, sign, n->n2);
				return n;
			}else{
				n->n1 = icast(4, sign, n->n1);
				n->n2 = icast(4, sign, n->n2);
				return icast(4, sign, n);
			}
		case OPEQ:
		case OPNE:
		case OPLT:
		case OPLE:
			/* comparisons yield a 4-byte integer; operands are brought to a common width */
			n->typ = type(TYPINT, 4, sign);
			if(s1 > 4 || s2 > 4){
				n->n1 = icast(8, sign, n->n1);
				n->n2 = icast(8, sign, n->n2);
				return n;
			}else{
				n->n1 = icast(4, sign, n->n1);
				n->n2 = icast(4, sign, n->n2);
				return n;
			}
		case OPLAND:
		case OPLOR:
			/* logical operators take their operands as they are */
			n->typ = type(TYPINT, 4, sign);
			return n;
		case OPLSH:
		case OPRSH:
			/*
			 * the result type follows the left operand only.
			 * a left operand of at most 4 bytes is first cast to 4 bytes.
			 */
			if(n->n1->typ->size <= 4)
				n->n1 = icast(4, n->n1->typ->sign, n->n1);
			n->typ = n->n1->typ;
			return icast(n->typ->size, n->typ->sign, n);
		default:
			sysfatal("typecheck: unknown op %d", n->op);
		}
		break;
	case OCAST:
		n->n1 = typecheck(n->n1);
		if(n->n1->typ == nil)
			break;
		/* allowed: integer to integer, identical types, integer to string */
		if(n->typ->type == TYPINT && n->n1->typ->type == TYPINT){
		}else if(n->typ == n->n1->typ){
		}else if(n->typ->type == TYPSTRING && n->n1->typ->type == TYPINT){
		}else
			error("can't cast from %τ to %τ", n->n1->typ, n->typ);
		break;
	case OLNOT:
		n->n1 = typecheck(n->n1);
		if(n->n1->typ == nil)
			break;
		if(n->n1->typ->type != TYPINT){
			error("%τ not allowed in operation", n->n1->typ);
			break;
		}
		n->typ = type(TYPINT, 4, 1);
		break;
	case OTERN:
		n->n1 = typecheck(n->n1);
		n->n2 = typecheck(n->n2);
		n->n3 = typecheck(n->n3);
		if(n->n1->typ == nil || n->n2->typ == nil || n->n3->typ == nil)
			break;
		if(n->n1->typ->type != TYPINT){
			error("%τ not allowed in operation", n->n1->typ);
			break;
		}
		/*
		 * the integer path needs both branches to be integers.
		 * a mixed pair must not get integer casts, because a cast
		 * from a non-integer to an integer is not a legal OCAST;
		 * it falls through to the error below instead.
		 */
		if(n->n2->typ->type == TYPINT && n->n3->typ->type == TYPINT){
			sign = n->n2->typ->sign && n->n3->typ->sign;
			s1 = n->n2->typ->size;
			s2 = n->n3->typ->size;
			if(s1 > 4 || s2 > 4){
				n->n2 = icast(8, sign, n->n2);
				n->n3 = icast(8, sign, n->n3);
				n->typ = type(TYPINT, 8, sign);
				return n;
			}else{
				n->n2 = icast(4, sign, n->n2);
				n->n3 = icast(4, sign, n->n3);
				n->typ = type(TYPINT, 4, sign);
				return n;
			}
		}else if(n->n2->typ == n->n3->typ){
			n->typ = n->n2->typ;
		}else
			error("don't know how to do ternary with %τ and %τ", n->n2->typ, n->n3->typ);
		break;
	case ORECORD:
	default: sysfatal("typecheck: unknown node type %α", n->type);
	}
	return n;
}

/*
 * evalop computes the binary operation op on the constants v1 and v2,
 * as used by constant folding.
 *
 * sign is non-zero when the operation is to be done on signed values.
 * it selects between signed and unsigned division, remainder, right
 * shift and ordering comparisons; the other operators give the same
 * bits either way.
 *
 * shift counts of 64 or more (a negative count looks like a huge
 * unsigned one) give 0, or the sign fill for a signed right shift.
 *
 * division or remainder by zero and unknown operators are fatal.
 */
vlong
evalop(int op, int sign, vlong v1, vlong v2)
{
	switch(/*oper*/op){
	case OPADD: return v1 + v2;
	case OPSUB: return v1 - v2;
	case OPMUL: return v1 * v2;
	case OPDIV:
		if(v2 == 0) sysfatal("division by zero");
		if(!sign)
			return (uvlong)v1 / (uvlong)v2;
		/*
		 * dividing the most negative value by -1 overflows and traps
		 * on common hardware.  negate in unsigned arithmetic instead,
		 * which wraps.
		 */
		if(v2 == -1)
			return -(uvlong)v1;
		return v1 / v2;
	case OPMOD:
		if(v2 == 0) sysfatal("division by zero");
		if(!sign)
			return (uvlong)v1 % (uvlong)v2;
		/* same overflow as for OPDIV; the remainder is always 0 */
		if(v2 == -1)
			return 0;
		return v1 % v2;
	case OPAND: return v1 & v2;
	case OPOR: return v1 | v2;
	case OPXOR: return v1 ^ v2;
	case OPXNOR: return ~(v1 ^ v2);
	case OPLSH:
		if((u64int)v2 >= 64)
			return 0;
		else
			return v1 << v2;
	case OPRSH:
		if(sign){
			if((u64int)v2 >= 64)
				return v1 >> 63;
			else
				return v1 >> v2;
		}else{
			if((u64int)v2 >= 64)
				return 0;
			else
				return (u64int)v1 >> v2;
		}
	case OPEQ: return v1 == v2;
	case OPNE: return v1 != v2;
	/* ordering depends on signedness, just like division and right shift */
	case OPLT: return sign ? v1 < v2 : (uvlong)v1 < (uvlong)v2;
	case OPLE: return sign ? v1 <= v2 : (uvlong)v1 <= (uvlong)v2;
	case OPLAND: return v1 && v2;
	case OPLOR: return v1 || v2;
	default:
		sysfatal("cfold: unknown op %.2x", op);
		return 0;
	}
}

/*
 * addtype sets the type of n to t and hands n back, so a freshly built
 * node can be typed and returned in one expression.
 */
Node *
addtype(Type *t, Node *n)
{
	Node *typed;

	typed = n;
	typed->typ = t;
	return typed;
}

/*
 * fold constants.
 *
 * cfold rewrites the expression bottom-up: any operator whose operands
 * have all become ONUM nodes is replaced by a new ONUM node carrying
 * the type of the operator it replaces.  a ternary with a constant
 * condition is replaced by the chosen branch.
 *
 * returns the possibly replaced root of the expression.
 */
static Node *
cfold(Node *n)
{
	switch(/*nodetype*/n->type){
	case ONUM:
	case OSYM:
	case OSTR:
		/* leaves: nothing to fold */
		return n;
	case OBIN:
		n->n1 = cfold(n->n1);
		n->n2 = cfold(n->n2);
		if(n->n1->type != ONUM || n->n2->type != ONUM)
			return n;
		return addtype(n->typ, node(ONUM, evalop(n->op, n->typ->sign, n->n1->num, n->n2->num)));
	case OLNOT:
		n->n1 = cfold(n->n1);
		/*
		 * ! yields an int, but the ONUM value is handed to node() as
		 * a vlong everywhere else.  widen it explicitly so the
		 * argument has the same size here too.
		 */
		if(n->n1->type == ONUM)
			return addtype(n->typ, node(ONUM, (vlong)!n->n1->num));
		return n;
	case OTERN:
		n->n1 = cfold(n->n1);
		n->n2 = cfold(n->n2);
		n->n3 = cfold(n->n3);
		if(n->n1->type == ONUM)
			return n->n1->num ? n->n2 : n->n3;
		return n;
	case OCAST:
		n->n1 = cfold(n->n1);
		/* only integer casts of constants are folded */
		if(n->n1->type != ONUM || n->typ->type != TYPINT)
			return n;
		/* high nibble: size in bytes, low bit: signedness */
		switch(n->typ->size << 4 | n->typ->sign){
		case 0x10: return addtype(n->typ, node(ONUM, (vlong)(u8int)n->n1->num));
		case 0x11: return addtype(n->typ, node(ONUM, (vlong)(s8int)n->n1->num));
		case 0x20: return addtype(n->typ, node(ONUM, (vlong)(u16int)n->n1->num));
		case 0x21: return addtype(n->typ, node(ONUM, (vlong)(s16int)n->n1->num));
		case 0x40: return addtype(n->typ, node(ONUM, (vlong)(u32int)n->n1->num));
		case 0x41: return addtype(n->typ, node(ONUM, (vlong)(s32int)n->n1->num));
		case 0x80: return addtype(n->typ, node(ONUM, n->n1->num));
		case 0x81: return addtype(n->typ, node(ONUM, n->n1->num));
		}
		/* any other size is left as a cast */
		return n;
	case ORECORD:
	default:
		fprint(2, "cfold: unknown type %α\n", n->type);
		return n;
	}
}

/*
 * calculate the minimum record size for each node of the expression.
 *
 * constants, strings and the DTV_TIME and DTV_PROBE variables need 0
 * bytes; any other variable needs the size of its type.  an operator
 * needs the sum of what its operands need, but never more than the size
 * of its own type.  a cast to a string always needs the size of the
 * string type.
 *
 * the result is stored in n->recsize; n is returned.
 * unknown node or symbol kinds are fatal.
 */
static Node *
calcrecsize(Node *n)
{
	Node *a, *b, *c;

	switch(/*nodetype*/n->type){
	case ONUM:
	case OSTR:
		n->recsize = 0;
		break;
	case OSYM:
		if(n->sym->type != SYMVAR){
			sysfatal("calcrecsize: unknown symbol type %d", n->sym->type);
			return nil;
		}
		if(n->sym->idx == DTV_TIME || n->sym->idx == DTV_PROBE)
			n->recsize = 0;
		else
			n->recsize = n->typ->size;
		break;
	case OBIN:
		a = n->n1 = calcrecsize(n->n1);
		b = n->n2 = calcrecsize(n->n2);
		n->recsize = min(n->typ->size, a->recsize + b->recsize);
		break;
	case OLNOT:
		a = n->n1 = calcrecsize(n->n1);
		n->recsize = min(n->typ->size, a->recsize);
		break;
	case OCAST:
		a = n->n1 = calcrecsize(n->n1);
		if(n->typ->type != TYPSTRING)
			n->recsize = min(n->typ->size, a->recsize);
		else
			n->recsize = n->typ->size;
		break;
	case OTERN:
		a = n->n1 = calcrecsize(n->n1);
		b = n->n2 = calcrecsize(n->n2);
		c = n->n3 = calcrecsize(n->n3);
		n->recsize = min(n->typ->size, a->recsize + b->recsize + c->recsize);
		break;
	case ORECORD:
	default:
		sysfatal("calcrecsize: unknown type %α", n->type);
		return nil;
	}
	return n;
}

/*
 * insert ORECORD nodes to mark the subexpression that we will pass to the kernel.
 *
 * relies on n->recsize as computed by calcrecsize.  a node with
 * recsize 0 is left alone.  a node whose recsize equals the size of its
 * type is wrapped in an ORECORD node of the same type.  otherwise the
 * search continues in the operands.
 *
 * returns the possibly wrapped node.  unknown node kinds are fatal.
 */
static Node *
insrecord(Node *n)
{
	Node *rec;

	if(n->recsize == 0)
		return n;
	if(n->recsize == n->typ->size){
		rec = node(ORECORD, n);
		return addtype(n->typ, rec);
	}

	switch(/*nodetype*/n->type){
	case ONUM:
	case OSTR:
	case OSYM:
		/* leaves have no operands to descend into */
		break;
	case OLNOT:
	case OCAST:
		n->n1 = insrecord(n->n1);
		break;
	case OBIN:
		n->n1 = insrecord(n->n1);
		n->n2 = insrecord(n->n2);
		break;
	case OTERN:
		n->n1 = insrecord(n->n1);
		n->n2 = insrecord(n->n2);
		n->n3 = insrecord(n->n3);
		break;
	case ORECORD:
	default:
		sysfatal("insrecord: unknown type %α", n->type);
		return nil;
	}
	return n;
}

/*
	delete useless casts.
	going down we determine the number of bits (m) needed to be correct at each stage.
	going back up we determine the number of low bits (n->databits) that may hold
	arbitrary data, i.e. each of which can be either 0 or 1.
	all other bits are either zero (n->upper == UPZX) or sign-extended (n->upper == UPSX).
	note that by number of bits we always mean a consecutive block starting from the LSB.

	we can delete a cast if it either affects only bits not needed (according to m) or
	if it's a no-op (according to databits, upper).

	n is the expression to process and m the number of low bits the caller needs.
	returns n, or the operand of n when n is a cast that got deleted.
*/
static Node *
elidecasts(Node *n, int m)
{
	switch(/*nodetype*/n->type){
	case OSTR:
		/* strings carry no bit information; databits and upper are not set */
		return n;
	case ONUM:
		n->databits = n->typ->size * 8;
		n->upper = n->typ->sign ? UPSX : UPZX;
		break;
	case OSYM:
		/* TODO: make less pessimistic */
		/* all 64 bits count as data; upper is left untouched */
		n->databits = 64;
		break;
	case OBIN:
		switch(/*oper*/n->op){
		case OPADD:
		case OPSUB:
			/* the low m bits of the result depend only on the low m bits of the operands */
			n->n1 = elidecasts(n->n1, m);
			n->n2 = elidecasts(n->n2, m);
			/* one extra bit for the carry or borrow */
			n->databits = min(64, max(n->n1->databits, n->n2->databits) + 1);
			n->upper = n->n1->upper | n->n2->upper;
			break;
		case OPMUL:
			n->n1 = elidecasts(n->n1, m);
			n->n2 = elidecasts(n->n2, m);
			/* a product may be as wide as both factors together */
			n->databits = min(64, n->n1->databits + n->n2->databits);
			n->upper = n->n1->upper | n->n2->upper;
			break;
		case OPAND:
		case OPOR:
		case OPXOR:
		case OPXNOR:
			n->n1 = elidecasts(n->n1, m);
			n->n2 = elidecasts(n->n2, m);
			if(n->op == OPAND && (n->n1->upper == UPZX || n->n2->upper == UPZX)){
				/* and with a zero-extended operand clears the upper bits of the result */
				n->upper = UPZX;
				if(n->n1->upper == UPZX && n->n2->upper == UPZX)
					n->databits = min(n->n1->databits, n->n2->databits);
				else if(n->n1->upper == UPZX)
					n->databits = n->n1->databits;
				else
					n->databits = n->n2->databits;
			}else{
				n->databits = max(n->n1->databits, n->n2->databits);
				n->upper = n->n1->upper | n->n2->upper;
			}
			break;
		case OPLSH:
			/* the shift count is needed in full */
			n->n1 = elidecasts(n->n1, m);
			n->n2 = elidecasts(n->n2, 64);
			/* a known non-negative count widens the data by that many bits, if it still fits */
			if(n->n2->type == ONUM && n->n2->num >= 0 && n->n1->databits + (uvlong)n->n2->num <= 64)
				n->databits = n->n1->databits + n->n2->num;
			else
				n->databits = 64;
			n->upper = n->n1->upper;
			break;
		case OPRSH:
			/* a right shift pulls upper bits down, so all bits of the left operand are needed */
			n->n1 = elidecasts(n->n1, 64);
			n->n2 = elidecasts(n->n2, 64);
			/* only when the extension of the operand matches the kind of shift is it preserved */
			if(n->n1->upper == n->typ->sign){
				n->databits = n->n1->databits;
				n->upper = n->n1->upper;
			}else{
				n->databits = 64;
				n->upper = UPZX;
			}
			break;
		case OPEQ:
		case OPNE:
		case OPLT:
		case OPLE:
		case OPLAND:
		case OPLOR:
			/* every operand bit matters; the result is a single zero-extended bit */
			n->n1 = elidecasts(n->n1, 64);
			n->n2 = elidecasts(n->n2, 64);
			n->databits = 1;
			n->upper = UPZX;
			break;
		case OPDIV:
		case OPMOD:
		default:
			/* nothing is assumed: all operand bits are needed and all result bits are data */
			n->n1 = elidecasts(n->n1, 64);
			n->n2 = elidecasts(n->n2, 64);
			n->databits = 64;
			n->upper = UPZX;
			break;
		}
		break;
	case OLNOT:
		n->n1 = elidecasts(n->n1, 64);
		n->databits = 1;
		n->upper = UPZX;
		break;
	case OCAST:
		switch(n->typ->type){
		case TYPINT:
			/* the operand only has to be correct in the bits that survive the cast */
			n->n1 = elidecasts(n->n1, min(n->typ->size * 8, m));
			if(n->n1->databits < n->typ->size * 8 && n->n1->upper == n->typ->sign){
				n->databits = n->n1->databits;
				n->upper = n->n1->upper;
			}else{
				n->databits = n->typ->size * 8;
				n->upper = n->typ->sign ? UPSX : UPZX;
			}
			/* the cast only touches bits above the m that are needed */
			if(n->typ->size * 8 >= m) return n->n1;
			/* the operand already fits and is extended the same way: no-op */
			if(n->typ->size * 8 >= n->n1->databits && n->typ->sign == n->n1->upper) return n->n1;
			/* a signed cast of a strictly narrower zero-extended value: no-op */
			if(n->typ->size * 8 > n->n1->databits && n->typ->sign && !n->n1->upper) return n->n1;
			break;
		case TYPSTRING:
			n->n1 = elidecasts(n->n1, 64);
			break;
		default:
			sysfatal("elidecasts: don't know how to cast %τ to %τ", n->n1->typ, n->typ);
		}
		break;
	case ORECORD:
		/* same bookkeeping as an integer cast, but the node itself is always kept */
		n->n1 = elidecasts(n->n1, min(n->typ->size * 8, m));
		if(n->n1->databits < n->typ->size * 8 && n->n1->upper == n->typ->sign){
			n->databits = n->n1->databits;
			n->upper = n->n1->upper;
		}else{
			n->databits = n->typ->size * 8;
			n->upper = n->typ->sign ? UPSX : UPZX;
		}
		break;
	case OTERN:
		/* the condition is needed in full, the branches only as far as the result is */
		n->n1 = elidecasts(n->n1, 64);
		n->n2 = elidecasts(n->n2, m);
		n->n3 = elidecasts(n->n3, m);
		if(n->n2->upper == n->n3->upper){
			n->databits = max(n->n2->databits, n->n3->databits);
			n->upper = n->n2->upper;
		}else{
			/*
			 * mixed extension: describe the result as sign-extended.
			 * the zero-extended branch gets one more data bit.
			 */
			if(n->n3->upper == UPSX)
				n->databits = max(min(64, n->n2->databits + 1), n->n3->databits);
			else
				n->databits = max(min(64, n->n3->databits + 1), n->n2->databits);
			n->upper = UPSX;
		}
		break;
	default: sysfatal("elidecasts: unknown type %α", n->type);
	}
//	print("need %d got %d%c %ε\n", n->needbits, n->databits, "ZS"[n->upper], n);
	return n;
}


/*
 * exprcheck runs the passes of this file over the expression n, in order:
 * typecheck, cfold, then calcrecsize and insrecord, and finally elidecasts.
 *
 * calcrecsize and insrecord are skipped when pred is non-zero.
 * if typecheck reported errors, the expression is returned right
 * after it and no later pass runs.
 * with dflag set the expression is printed after every pass.
 *
 * returns the rewritten expression.
 */
Node *
exprcheck(Node *n, int pred)
{
	if(dflag)
		print("start       %ε\n", n);

	n = typecheck(n);
	if(errors)
		return n;
	if(dflag)
		print("typecheck   %ε\n", n);

	n = cfold(n);
	if(dflag)
		print("cfold       %ε\n", n);

	if(!pred){
		n = calcrecsize(n);
		n = insrecord(n);
		if(dflag)
			print("insrecord   %ε\n", n);
	}

	n = elidecasts(n, 64);
	if(dflag)
		print("elidecasts  %ε\n", n);

	return n;
}