use ariadne::{Color, Label, Report, ReportKind}; use std::collections::HashMap; use crate::ast::{self, MemSize}; use crate::intrinsics::Intrinsics; use crate::parser::{Sources, Span}; use ast::Type::*; type Result = std::result::Result; struct Var { span: Span, type_: ast::Type, mutable: bool, } type Vars = HashMap; pub fn tc_script(script: &mut ast::Script, sources: &Sources) -> Result<()> { let mut context = Context { sources, global_vars: HashMap::new(), functions: HashMap::new(), locals: ast::Locals::default(), local_vars: LocalVars::new(), block_stack: Vec::new(), return_type: None, intrinsics: Intrinsics::new(), }; let mut result = Ok(()); for import in &script.imports { match import.type_ { ast::ImportType::Variable { ref name, type_, mutable, } => { if let Some(Var { span, .. }) = context.global_vars.get(name) { result = report_duplicate_definition( "Global already defined", &import.span, span, sources, ); } else { context.global_vars.insert( name.clone(), Var { type_, span: import.span.clone(), mutable, }, ); } } ast::ImportType::Function { ref name, ref params, result: ref result_type, } => { if let Some(fnc) = context.functions.get(name) { result = report_duplicate_definition( "Function already defined", &import.span, &fnc.span, sources, ); } else { context.functions.insert( name.clone(), FunctionType { span: import.span.clone(), params: params.clone(), type_: *result_type, }, ); } } ast::ImportType::Memory(..) => (), } } for v in &mut script.global_vars { if let Some(Var { span, .. }) = context.global_vars.get(&v.name) { result = report_duplicate_definition("Global already defined", &v.span, span, sources); } else { tc_const(&mut v.value, sources)?; if v.type_ != v.value.type_ { if v.type_.is_some() { result = type_mismatch(v.type_, &v.span, v.value.type_, &v.value.span, sources); } else { v.type_ = v.value.type_; } } context.global_vars.insert( v.name.clone(), Var { type_: v.type_.unwrap(), span: v.span.clone(), mutable: v.mutable, }, ); } } for f in &script.functions { let params = f.params.iter().map(|(_, t)| *t).collect(); if let Some(fnc) = context.functions.get(&f.name) { result = report_duplicate_definition( "Function already defined", &f.span, &fnc.span, sources, ); } else { context.functions.insert( f.name.clone(), FunctionType { params, type_: f.type_, span: f.span.clone(), }, ); } } for f in &mut script.functions { context.local_vars.clear(); context.local_vars.push_scope(); for (name, type_) in &f.params { if let Some(span) = context .local_vars .get(name) .map(|id| &context.locals[id].span) .or_else(|| context.global_vars.get(name).map(|v| &v.span)) { result = report_duplicate_definition("Variable already defined", &f.span, span, sources); } else { context.local_vars.insert( name.clone(), context .locals .add_param(f.span.clone(), name.clone(), *type_), ); } } context.return_type = f.type_; tc_expression(&mut context, &mut f.body)?; let mut local_mapping: Vec<(ast::Type, usize)> = context .locals .locals .iter() .enumerate() .filter(|(_, local)| local.index.is_some()) .map(|(index, local)| (local.type_, index)) .collect(); local_mapping.sort_by_key(|&(t, _)| t); let locals_start = context.locals.params.len(); for (id, (_, index)) in local_mapping.into_iter().enumerate() { context.locals.locals[index].index = Some((locals_start + id) as u32); } f.locals = std::mem::take(&mut context.locals); if f.body.type_ != f.type_ { result = type_mismatch(f.type_, &f.span, f.body.type_, &f.body.span, sources); } } let mut start_function: Option<&ast::Function> = None; for f in &script.functions { if f.start { if !f.params.is_empty() || f.type_.is_some() { Report::build(ReportKind::Error, f.span.0, f.span.1.start) .with_message("Start function can't have params or a return value") .with_label( Label::new(f.span.clone()) .with_message("Start function can't have params or a return value") .with_color(Color::Red), ) .finish() .eprint(sources) .unwrap(); result = Err(()); } if let Some(prev) = start_function { result = report_duplicate_definition( "Start function already defined", &f.span, &prev.span, sources, ); } else { start_function = Some(f); } } } for data in &mut script.data { tc_const(&mut data.offset, sources)?; if data.offset.type_ != Some(I32) { result = type_mismatch( Some(I32), &data.offset.span, data.offset.type_, &data.offset.span, sources, ); } for values in &mut data.data { match values { ast::DataValues::Array { type_, values } => { let needed_type = match type_ { ast::DataType::I8 | ast::DataType::I16 | ast::DataType::I32 => { ast::Type::I32 } ast::DataType::I64 => ast::Type::I64, ast::DataType::F32 => ast::Type::F32, ast::DataType::F64 => ast::Type::F64, }; for value in values { tc_const(value, sources)?; if value.type_ != Some(needed_type) { result = type_mismatch( Some(needed_type), &value.span, value.type_, &value.span, sources, ); } } } ast::DataValues::String(_) | ast::DataValues::File { .. } => (), } } } result } struct FunctionType { span: Span, params: Vec, type_: Option, } struct Context<'a> { sources: &'a Sources, global_vars: Vars, functions: HashMap, locals: ast::Locals, local_vars: LocalVars, block_stack: Vec, return_type: Option, intrinsics: Intrinsics, } struct LocalVars(Vec>); impl LocalVars { fn new() -> LocalVars { LocalVars(Vec::new()) } fn get(&self, name: &str) -> Option { self.0 .iter() .rev() .filter_map(|scope| scope.get(name)) .next() .copied() } fn get_in_current(&self, name: &str) -> Option { self.0.last().unwrap().get(name).copied() } fn clear(&mut self) { self.0.clear(); } fn push_scope(&mut self) { self.0.push(HashMap::new()); } fn pop_scope(&mut self) { self.0.pop(); } fn insert(&mut self, name: String, id: u32) { self.0.last_mut().unwrap().insert(name, id); } } fn report_duplicate_definition( msg: &str, span: &Span, prev_span: &Span, sources: &Sources, ) -> Result<()> { Report::build(ReportKind::Error, span.0, span.1.start) .with_message(msg) .with_label( Label::new(span.clone()) .with_message(msg) .with_color(Color::Red), ) .with_label( Label::new(prev_span.clone()) .with_message("Previous definition was here") .with_color(Color::Yellow), ) .finish() .eprint(sources) .unwrap(); Err(()) } fn type_mismatch( type1: Option, span1: &Span, type2: Option, span2: &Span, sources: &Sources, ) -> Result<()> { Report::build(ReportKind::Error, span2.0, span2.1.start) .with_message("Type mismatch") .with_label( Label::new(span1.clone()) .with_message(format!( "Expected type {:?}...", type1 .map(|t| format!("{:?}", t)) .unwrap_or_else(|| "void".to_string()) )) .with_color(Color::Yellow), ) .with_label( Label::new(span2.clone()) .with_message(format!( "...but found type {}", type2 .map(|t| format!("{:?}", t)) .unwrap_or_else(|| "void".to_string()) )) .with_color(Color::Red), ) .finish() .eprint(sources) .unwrap(); Err(()) } pub fn report_error(msg: &str, span: &Span, sources: &Sources) -> Result<()> { Report::build(ReportKind::Error, span.0, span.1.start) .with_message(msg) .with_label( Label::new(span.clone()) .with_message(msg) .with_color(Color::Red), ) .finish() .eprint(sources) .unwrap(); Err(()) } fn expected_type(span: &Span, sources: &Sources) -> Result<()> { report_error( "Expected value but found expression of type void", span, sources, ) } fn unknown_variable(span: &Span, sources: &Sources) -> Result<()> { report_error("Unknown variable", span, sources) } fn immutable_assign(span: &Span, sources: &Sources) -> Result<()> { report_error("Trying to assign to immutable variable", span, sources) } fn missing_label(span: &Span, sources: &Sources) -> Result<()> { report_error("Label not found", span, sources) } fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<()> { expr.type_ = match expr.expr { ast::Expr::Block { ref mut statements, ref mut final_expression, } => { context.local_vars.push_scope(); for stmt in statements { tc_expression(context, stmt)?; } let type_ = if let Some(final_expression) = final_expression { tc_expression(context, final_expression)?; final_expression.type_ } else { None }; context.local_vars.pop_scope(); type_ } ast::Expr::Let { ref mut value, ref mut type_, ref name, let_type, ref mut local_id, .. } => { if let Some(ref mut value) = value { tc_expression(context, value)?; if let Some(type_) = type_ { if Some(*type_) != value.type_ { return type_mismatch( Some(*type_), &expr.span, value.type_, &value.span, context.sources, ); } } else if value.type_.is_none() { return expected_type(&value.span, context.sources); } else { *type_ = value.type_; } } if let Some(type_) = type_ { let store = let_type != ast::LetType::Inline; let id = context .local_vars .get_in_current(name) .filter(|id| { let local = &context.locals[*id]; local.type_ == *type_ && store == local.index.is_some() }) .unwrap_or_else(|| { context .locals .add_local(expr.span.clone(), name.clone(), *type_, store) }); *local_id = Some(id); context.local_vars.insert(name.clone(), id); } else { return report_error("Type missing", &expr.span, context.sources); } None } ast::Expr::Peek(ref mut mem_location) => { tc_mem_location(context, mem_location)?; let ty = match mem_location.size { MemSize::Float => F32, _ => I32, }; Some(ty) } ast::Expr::Poke { ref mut mem_location, ref mut value, } => { tc_mem_location(context, mem_location)?; tc_expression(context, value)?; let ty = match mem_location.size { MemSize::Float => F32, _ => I32, }; if value.type_ != Some(ty) { return type_mismatch( Some(ty), &expr.span, value.type_, &value.span, context.sources, ); } None } ast::Expr::I32Const(_) => Some(ast::Type::I32), ast::Expr::I64Const(_) => Some(ast::Type::I64), ast::Expr::F32Const(_) => Some(ast::Type::F32), ast::Expr::F64Const(_) => Some(ast::Type::F64), ast::Expr::UnaryOp { op, ref mut value } => { tc_expression(context, value)?; if value.type_.is_none() { return expected_type(&value.span, context.sources); } use ast::Type::*; use ast::UnaryOp::*; Some(match (value.type_.unwrap(), op) { (t, Negate) => t, (I32 | I64, Not) => I32, (_, Not) => { return type_mismatch( Some(I32), &expr.span, value.type_, &value.span, context.sources, ) } }) } ast::Expr::BinOp { op, ref mut left, ref mut right, } => { tc_expression(context, left)?; tc_expression(context, right)?; if let Some(type_) = left.type_ { if left.type_ != right.type_ { return type_mismatch( Some(type_), &left.span, right.type_, &right.span, context.sources, ); } } else { return expected_type(&left.span, context.sources); } use ast::BinOp::*; match op { Add | Sub | Mul | Div => left.type_, Rem | And | Or | Xor | Shl | ShrU | ShrS | DivU | RemU => { if left.type_ != Some(I32) && left.type_ != Some(I64) { return type_mismatch( Some(I32), &left.span, left.type_, &left.span, context.sources, ); } else { left.type_ } } Eq | Ne | Lt | Le | Gt | Ge => Some(I32), LtU | LeU | GtU | GeU => { if left.type_ != Some(I32) && left.type_ != Some(I64) { return type_mismatch( Some(I32), &left.span, left.type_, &left.span, context.sources, ); } else { Some(I32) } } } } ast::Expr::Variable { ref name, ref mut local_id, } => { if let Some(id) = context.local_vars.get(name) { *local_id = Some(id); Some(context.locals[id].type_) } else if let Some(&Var { type_, .. }) = context.global_vars.get(name) { Some(type_) } else { return unknown_variable(&expr.span, context.sources); } } ast::Expr::Assign { ref name, ref mut value, ref mut local_id, } => { tc_expression(context, value)?; let (type_, span) = if let Some(id) = context.local_vars.get(name) { *local_id = Some(id); let local = &context.locals[id]; if local.index.is_none() { return immutable_assign(&expr.span, context.sources); } (local.type_, &local.span) } else if let Some(&Var { type_, ref span, mutable, }) = context.global_vars.get(name) { if !mutable { return immutable_assign(&expr.span, context.sources); } (type_, span) } else { return unknown_variable(&expr.span, context.sources); }; if value.type_ != Some(type_) { return type_mismatch(Some(type_), span, value.type_, &value.span, context.sources); } None } ast::Expr::LocalTee { ref name, ref mut value, ref mut local_id, } => { tc_expression(context, value)?; if let Some(id) = context.local_vars.get(name) { *local_id = Some(id); let local = &context.locals[id]; if local.index.is_none() { return immutable_assign(&expr.span, context.sources); } if value.type_ != Some(local.type_) { return type_mismatch( Some(local.type_), &local.span, value.type_, &value.span, context.sources, ); } Some(local.type_) } else { return unknown_variable(&expr.span, context.sources); } } ast::Expr::Loop { ref label, ref mut block, } => { context.block_stack.push(label.clone()); tc_expression(context, block)?; context.block_stack.pop(); block.type_ } ast::Expr::LabelBlock { ref label, ref mut block, } => { context.block_stack.push(label.clone()); tc_expression(context, block)?; context.block_stack.pop(); if block.type_ != None { // TODO: implement, requires branches to optionally provide values return type_mismatch(None, &expr.span, block.type_, &block.span, context.sources); } None } ast::Expr::Branch(ref label) => { if !context.block_stack.contains(label) { return missing_label(&expr.span, context.sources); } None } ast::Expr::BranchIf { ref mut condition, ref label, } => { tc_expression(context, condition)?; if condition.type_ != Some(I32) { return type_mismatch( Some(I32), &expr.span, condition.type_, &condition.span, context.sources, ); } if !context.block_stack.contains(label) { return missing_label(&expr.span, context.sources); } None } ast::Expr::Cast { ref mut value, type_, } => { tc_expression(context, value)?; if value.type_.is_none() { return expected_type(&expr.span, context.sources); } Some(type_) } ast::Expr::FuncCall { ref name, ref mut params, } => { for param in params.iter_mut() { tc_expression(context, param)?; if param.type_.is_none() { return expected_type(¶m.span, context.sources); } } if let Some(load) = context.intrinsics.find_load(name) { tc_memarg(context, params.as_mut_slice(), &expr.span)?; Some(load.type_) } else if let Some(store) = context.intrinsics.find_store(name) { if let Some(value) = params.first_mut() { tc_expression(context, value)?; if value.type_ != Some(store.type_) { type_mismatch( Some(store.type_), &expr.span, value.type_, &value.span, context.sources, )?; } } else { return report_error("Missing parameters", &expr.span, context.sources); } tc_memarg(context, &mut params[1..], &expr.span)?; None } else if let Some(type_map) = context .functions .get(name) .map(|fnc| HashMap::from_iter([(fnc.params.clone(), fnc.type_)])) .or_else(|| context.intrinsics.find_types(name)) { if let Some(rtype) = type_map.get(¶ms.iter().map(|p| p.type_.unwrap()).collect::>()) { *rtype } else { let mut report = Report::build(ReportKind::Error, expr.span.0, expr.span.1.start) .with_message("No matching function found"); for (params, rtype) in type_map { let param_str: Vec<_> = params.into_iter().map(|t| t.to_string()).collect(); let msg = format!( "Found {}({}){}", name, param_str.join(", "), if let Some(rtype) = rtype { format!(" -> {}", rtype) } else { String::new() } ); report = report.with_label(Label::new(expr.span.clone()).with_message(msg)); } report.finish().eprint(context.sources).unwrap(); return Err(()); } } else { return report_error( &format!("Unknown function {}", name), &expr.span, context.sources, ); } } ast::Expr::Select { ref mut condition, ref mut if_true, ref mut if_false, } => { tc_expression(context, condition)?; tc_expression(context, if_true)?; tc_expression(context, if_false)?; if condition.type_ != Some(ast::Type::I32) { return type_mismatch( Some(I32), &condition.span, condition.type_, &condition.span, context.sources, ); } if if_true.type_.is_some() { if if_true.type_ != if_false.type_ { return type_mismatch( if_true.type_, &if_true.span, if_false.type_, &if_false.span, context.sources, ); } } else { return expected_type(&if_true.span, context.sources); } if_true.type_ } ast::Expr::If { ref mut condition, ref mut if_true, ref mut if_false, } => { tc_expression(context, condition)?; tc_expression(context, if_true)?; if let Some(ref mut if_false) = if_false { tc_expression(context, if_false)?; if if_true.type_ != if_false.type_ { return type_mismatch( if_true.type_, &if_true.span, if_false.type_, &if_false.span, context.sources, ); } else { if_true.type_ } } else { None } } ast::Expr::Return { ref mut value } => { if let Some(ref mut value) = value { tc_expression(context, value)?; if value.type_ != context.return_type { return type_mismatch( context.return_type, &expr.span, value.type_, &value.span, context.sources, ); } } None } ast::Expr::First { ref mut value, ref mut drop, } => { tc_expression(context, value)?; tc_expression(context, drop)?; value.type_ } ast::Expr::Error => unreachable!(), }; Ok(()) } fn tc_mem_location<'a>( context: &mut Context<'a>, mem_location: &mut ast::MemoryLocation, ) -> Result<()> { tc_expression(context, &mut mem_location.left)?; tc_const(&mut mem_location.right, context.sources)?; if mem_location.left.type_ != Some(I32) { return type_mismatch( Some(I32), &mem_location.left.span, mem_location.left.type_, &mem_location.left.span, context.sources, ); } if mem_location.right.type_ != Some(I32) { return type_mismatch( Some(I32), &mem_location.right.span, mem_location.right.type_, &mem_location.right.span, context.sources, ); } Ok(()) } fn tc_const(expr: &mut ast::Expression, sources: &Sources) -> Result<()> { use ast::Expr::*; expr.type_ = Some(match expr.expr { I32Const(_) => I32, I64Const(_) => I64, F32Const(_) => F32, F64Const(_) => F64, _ => return report_error("Expected constant value", &expr.span, sources), }); Ok(()) } fn tc_memarg(context: &mut Context, params: &mut [ast::Expression], span: &Span) -> Result<()> { if params.is_empty() || params.len() > 3 { let msg = if params.is_empty() { "Missing base address parameter" } else { "Too many MemArg parameters" }; return report_error(msg, span, context.sources); } for (index, param) in params.iter_mut().enumerate() { tc_expression(context, param)?; if param.type_ != Some(I32) { return type_mismatch(Some(I32), &span, param.type_, ¶m.span, context.sources); } if index > 0 { tc_const(param, context.sources)?; } if index == 2 { let align = param.const_i32(); if align < 0 || align > 4 { return report_error( &format!("Alignment {} out of range (0-4)", align), ¶m.span, context.sources, ); } } } Ok(()) }