From d4a5d62255a7dba3fd307375947df4ec87f64231 Mon Sep 17 00:00:00 2001 From: Dennis Ranke Date: Sun, 24 Oct 2021 22:48:04 +0200 Subject: [PATCH] implement type checking and constant folding for simple example --- src/ast.rs | 20 ++++- src/constfold.rs | 77 +++++++++++++++++ src/main.rs | 8 +- src/parser.rs | 101 +++++++++++++--------- src/typecheck.rs | 217 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 377 insertions(+), 46 deletions(-) create mode 100644 src/constfold.rs create mode 100644 src/typecheck.rs diff --git a/src/ast.rs b/src/ast.rs index 81dd5a3..cbe2cee 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -3,7 +3,8 @@ pub struct Position(pub usize); #[derive(Debug)] pub struct Script<'a> { - pub items: Vec>, + pub global_vars: Vec>, + pub functions: Vec> } #[derive(Debug)] @@ -40,6 +41,7 @@ pub struct Block<'a> { pub enum Statement<'a> { LocalVariable(LocalVariable<'a>), Poke { + position: Position, mem_location: MemoryLocation<'a>, value: Expression<'a>, }, @@ -63,7 +65,19 @@ pub struct LocalVariable<'a> { } #[derive(Debug)] -pub enum Expression<'a> { +pub struct Expression<'a> { + pub type_: Option, + pub expr: Expr<'a>, +} + +impl<'a> From> for Expression<'a> { + fn from(expr: Expr<'a>) -> Expression<'a> { + Expression { type_: None, expr } + } +} + +#[derive(Debug)] +pub enum Expr<'a> { I32Const(i32), Variable { position: Position, @@ -92,7 +106,7 @@ pub enum Expression<'a> { }, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum BinOp { Add, Sub, diff --git a/src/constfold.rs b/src/constfold.rs new file mode 100644 index 0000000..94026df --- /dev/null +++ b/src/constfold.rs @@ -0,0 +1,77 @@ +use crate::ast; + +pub fn fold_script(script: &mut ast::Script) { + for func in &mut script.functions { + fold_block(&mut func.body); + } +} + +fn fold_block(block: &mut ast::Block) { + for stmt in &mut block.statements { + match stmt { + ast::Statement::LocalVariable(lv) => { + if let Some(ref mut expr) = lv.value { + fold_expr(expr); + } + 
} + ast::Statement::Expression(expr) => fold_expr(expr), + ast::Statement::Poke { + mem_location, + value, + .. + } => { + fold_mem_location(mem_location); + fold_expr(value); + } + } + } + if let Some(ref mut expr) = block.final_expression { + fold_expr(expr); + } +} + +fn fold_mem_location(mem_location: &mut ast::MemoryLocation) { + fold_expr(&mut mem_location.left); + fold_expr(&mut mem_location.right); +} + +fn fold_expr(expr: &mut ast::Expression) { + use ast::BinOp::*; + match expr.expr { + ast::Expr::BinOp { + ref mut left, op, ref mut right, .. + } => { + fold_expr(left); + fold_expr(right); + dbg!(&left.expr, &right.expr); + match (&left.expr, &right.expr) { + (&ast::Expr::I32Const(left), &ast::Expr::I32Const(right)) => { + let result = match op { + Add => left.wrapping_add(right), + Sub => left.wrapping_sub(right), + Mul => left.wrapping_mul(right), + Div => left / right, // TODO: protect against division by zero + Rem => left % right, // TODO: check correct behavior with negative operands + And => left & right, + Or => left | right, + Xor => left ^ right, + Eq => (left == right) as i32, + Ne => (left != right) as i32, + Lt => (left < right) as i32, + Le => (left <= right) as i32, + Gt => (left > right) as i32, + Ge => (left >= right) as i32, + }; + expr.expr = ast::Expr::I32Const(result); + } + _ => () + } + } + ast::Expr::I32Const(_) | ast::Expr::Variable { .. } => (), + ast::Expr::LocalTee { ref mut value, .. } => fold_expr(value), + ast::Expr::Loop { ref mut block, .. } => fold_block(block), + ast::Expr::BranchIf { + ref mut condition, .. 
+ } => fold_expr(condition), + } +} diff --git a/src/main.rs b/src/main.rs index 2ce1c7c..de0a3c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,17 @@ mod parser; mod ast; +mod typecheck; +mod constfold; fn main() { let input = include_str!("../test.hw"); let result = parser::parse(input); match result { - Ok(script) => {dbg!(script);}, + Ok(mut script) => { + constfold::fold_script(&mut script); + typecheck::tc_script(&mut script).unwrap(); + dbg!(script); + }, Err(err) => println!("error: {}", nom::error::convert_error(input, err)) } } diff --git a/src/parser.rs b/src/parser.rs index c3c8658..f3f349d 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -19,7 +19,21 @@ pub fn parse(s: &str) -> Result> { fn script(s: &str) -> IResult { let (s, items) = many0(top_level_item)(s)?; - Ok((s, ast::Script { items })) + let mut global_vars = vec![]; + let mut functions = vec![]; + for item in items { + match item { + ast::TopLevelItem::GlobalVar(v) => global_vars.push(v), + ast::TopLevelItem::Function(f) => functions.push(f), + } + } + Ok(( + s, + ast::Script { + global_vars, + functions, + }, + )) } fn top_level_item(s: &str) -> IResult { @@ -88,7 +102,7 @@ fn block(s: &str) -> IResult { s, ast::Block { statements, - final_expression, + final_expression: final_expression.map(|e| e.into()), }, )) } @@ -96,22 +110,25 @@ fn block(s: &str) -> IResult { fn statement(s: &str) -> IResult { alt(( map(local_var, ast::Statement::LocalVariable), - map( - terminated(expression, ws(char(';'))), - ast::Statement::Expression, - ), + map(terminated(expression, ws(char(';'))), |e| { + ast::Statement::Expression(e.into()) + }), map( terminated(block_expression, not(peek(ws(char('}'))))), - ast::Statement::Expression, + |e| ast::Statement::Expression(e.into()), ), map( terminated( - pair(mem_location, preceded(ws(char('=')), expression)), + pair( + mem_location, + ws(pair(position, preceded(char('='), expression))), + ), ws(char(';')), ), - |(mem_location, value)| ast::Statement::Poke 
{ + |(mem_location, (position, value))| ast::Statement::Poke { + position, mem_location, - value, + value: value.into(), }, ), ))(s) @@ -131,7 +148,7 @@ fn local_var(s: &str) -> IResult { position, name: name, type_, - value, + value: value.map(|v| v.into()), }, )) } @@ -151,31 +168,31 @@ fn mem_location(s: &str) -> IResult { ast::MemoryLocation { position, size, - left, - right, + left: left.into(), + right: right.into(), }, )) } -fn expression(s: &str) -> IResult { +fn expression(s: &str) -> IResult { expression_cmp(s) } -fn expression_atom(s: &str) -> IResult { +fn expression_atom(s: &str) -> IResult { alt(( branch_if, block_expression, map( separated_pair(pair(ws(position), identifier), ws(tag(":=")), expression), - |((position, name), value)| ast::Expression::LocalTee { + |((position, name), value)| ast::Expr::LocalTee { position, name: name, - value: Box::new(value), + value: Box::new(value.into()), }, ), - map(integer, |v| ast::Expression::I32Const(v)), + map(integer, |v| ast::Expr::I32Const(v)), map(ws(pair(position, identifier)), |(position, name)| { - ast::Expression::Variable { + ast::Expr::Variable { position, name: name, } @@ -184,7 +201,7 @@ fn expression_atom(s: &str) -> IResult { ))(s) } -fn branch_if(s: &str) -> IResult { +fn branch_if(s: &str) -> IResult { let (s, position) = ws(position)(s)?; let (s, _) = tag("branch_if")(s)?; cut(move |s| { @@ -194,16 +211,16 @@ fn branch_if(s: &str) -> IResult { Ok(( s, - ast::Expression::BranchIf { + ast::Expr::BranchIf { position, - condition: Box::new(condition), + condition: Box::new(condition.into()), label: label, }, )) })(s) } -fn expression_product(s: &str) -> IResult { +fn expression_product(s: &str) -> IResult { let (s, mut init) = map(expression_atom, Some)(s)?; fold_many0( pair( @@ -218,17 +235,17 @@ fn expression_product(s: &str) -> IResult { '%' => ast::BinOp::Rem, _ => unreachable!(), }; - ast::Expression::BinOp { + ast::Expr::BinOp { position, op, - left: Box::new(left), - right: 
Box::new(right), + left: Box::new(left.into()), + right: Box::new(right.into()), } }, )(s) } -fn expression_sum(s: &str) -> IResult { +fn expression_sum(s: &str) -> IResult { let (s, mut init) = map(expression_product, Some)(s)?; fold_many0( pair( @@ -242,17 +259,17 @@ fn expression_sum(s: &str) -> IResult { } else { ast::BinOp::Sub }; - ast::Expression::BinOp { + ast::Expr::BinOp { position, op, - left: Box::new(left), - right: Box::new(right), + left: Box::new(left.into()), + right: Box::new(right.into()), } }, )(s) } -fn expression_bit(s: &str) -> IResult { +fn expression_bit(s: &str) -> IResult { let (s, mut init) = map(expression_sum, Some)(s)?; fold_many0( pair( @@ -267,17 +284,17 @@ fn expression_bit(s: &str) -> IResult { '^' => ast::BinOp::Xor, _ => unreachable!(), }; - ast::Expression::BinOp { + ast::Expr::BinOp { position, op, - left: Box::new(left), - right: Box::new(right), + left: Box::new(left.into()), + right: Box::new(right.into()), } }, )(s) } -fn expression_cmp(s: &str) -> IResult { +fn expression_cmp(s: &str) -> IResult { let (s, mut init) = map(expression_bit, Some)(s)?; fold_many0( pair( @@ -305,21 +322,21 @@ fn expression_cmp(s: &str) -> IResult { ">" => ast::BinOp::Gt, _ => unreachable!(), }; - ast::Expression::BinOp { + ast::Expr::BinOp { position, op, - left: Box::new(left), - right: Box::new(right), + left: Box::new(left.into()), + right: Box::new(right.into()), } }, )(s) } -fn block_expression(s: &str) -> IResult { +fn block_expression(s: &str) -> IResult { loop_(s) } -fn loop_(s: &str) -> IResult { +fn loop_(s: &str) -> IResult { let (s, position) = ws(position)(s)?; let (s, _) = tag("loop")(s)?; cut(move |s| { @@ -328,10 +345,10 @@ fn loop_(s: &str) -> IResult { Ok(( s, - ast::Expression::Loop { + ast::Expr::Loop { position, label: label, - block: Box::new(block), + block: Box::new(block.into()), }, )) })(s) diff --git a/src/typecheck.rs b/src/typecheck.rs new file mode 100644 index 0000000..2ef6f42 --- /dev/null +++ b/src/typecheck.rs 
@@ -0,0 +1,217 @@ +use std::collections::HashMap; + +use crate::ast; +use ast::Type::*; + +#[derive(Debug)] +pub struct Error { + pub position: ast::Position, + pub message: String, +} + +type Result = std::result::Result; + +type Vars<'a> = HashMap<&'a str, ast::Type>; + +pub fn tc_script(script: &mut ast::Script) -> Result<()> { + let mut context = Context { + global_vars: HashMap::new(), + local_vars: HashMap::new(), + }; + + for v in &script.global_vars { + if context.global_vars.contains_key(v.name) { + return Err(Error { + position: v.position, + message: "Duplicate global variable".into(), + }); + } + context.global_vars.insert(v.name, v.type_); + } + + for f in &mut script.functions { + context.local_vars.clear(); + for (name, type_) in &f.params { + if context.local_vars.contains_key(name) || context.global_vars.contains_key(name) { + return Err(Error { + position: f.position, + message: format!("Variable already defined '{}'", name), + }); + } + context.local_vars.insert(name, *type_); + } + + tc_block(&mut context, &mut f.body)?; + } + + Ok(()) +} + +struct Context<'a> { + global_vars: Vars<'a>, + local_vars: Vars<'a>, +} + +fn tc_block<'a>(context: &mut Context<'a>, block: &mut ast::Block<'a>) -> Result<()> { + for stmt in &mut block.statements { + match *stmt { + ast::Statement::Expression(ref mut expr) => tc_expression(context, expr)?, + ast::Statement::LocalVariable(ref mut lv) => { + if let Some(ref mut value) = lv.value { + tc_expression(context, value)?; + if lv.type_.is_none() { + lv.type_ = value.type_; + } else if lv.type_ != value.type_ { + return Err(Error { + position: lv.position, + message: "Mismatched types".into(), + }); + } + } + if let Some(type_) = lv.type_ { + if context.local_vars.contains_key(lv.name) + || context.global_vars.contains_key(lv.name) + { + return Err(Error { + position: lv.position, + message: format!("Variable '{}' already defined", lv.name), + }); + } + context.local_vars.insert(lv.name, type_); + } else { + return 
Err(Error { + position: lv.position, + message: "Missing type".into(), + }); + } + } + ast::Statement::Poke { + position, + ref mut mem_location, + ref mut value, + } => { + tc_mem_location(context, mem_location)?; + tc_expression(context, value)?; + if value.type_ != Some(I32) { + return Err(Error { + position, + message: "Type mismatch".into(), + }); + } + } + } + } + if let Some(ref mut expr) = block.final_expression { + tc_expression(context, expr)?; + } + Ok(()) +} + +fn tc_expression<'a>(context: &mut Context<'a>, expr: &mut ast::Expression<'a>) -> Result<()> { + expr.type_ = match expr.expr { + ast::Expr::I32Const(_) => Some(ast::Type::I32), + ast::Expr::BinOp { + position, + op, + ref mut left, + ref mut right, + } => { + tc_expression(context, left)?; + tc_expression(context, right)?; + if left.type_.is_none() || left.type_ != right.type_ { + return Err(Error { + position, + message: "Type mismatch".into(), + }); + } + use ast::BinOp::*; + match op { + Add | Sub | Mul | Div => left.type_, + Rem | And | Or | Xor => { + if left.type_ != Some(I32) { + return Err(Error { + position, + message: "Unsupported type".into(), + }); + } else { + left.type_ + } + } + Eq | Ne | Lt | Le | Gt | Ge => Some(I32), + } + } + ast::Expr::Variable { position, name } => { + if let Some(&type_) = context + .global_vars + .get(name) + .or_else(|| context.local_vars.get(name)) + { + Some(type_) + } else { + return Err(Error { + position, + message: "Variable not found".into(), + }); + } + } + ast::Expr::LocalTee { + position, + name, + ref mut value, + } => { + tc_expression(context, value)?; + if let Some(&type_) = context.local_vars.get(name) { + if value.type_ != Some(type_) { + return Err(Error { + position, + message: "Type mismatch".into(), + }); + } + Some(type_) + } else { + return Err(Error { + position, + message: format!("No local variable '{}' found", name), + }); + } + } + ast::Expr::Loop { + position: _, + label: _, + ref mut block, + } => { + tc_block(context, 
block)?; + block.final_expression.as_ref().and_then(|e| e.type_) + } + ast::Expr::BranchIf { + position, + ref mut condition, + label: _, + } => { + tc_expression(context, condition)?; + if condition.type_ != Some(I32) { + return Err(Error { + position, + message: "Condition has to be i32".into(), + }); + } + None + } + }; + Ok(()) +} + +fn tc_mem_location<'a>( + context: &mut Context<'a>, + mem_location: &mut ast::MemoryLocation<'a>, +) -> Result<()> { + tc_expression(context, &mut mem_location.left)?; + tc_expression(context, &mut mem_location.right)?; + if mem_location.left.type_ != Some(I32) || mem_location.right.type_ != Some(I32) { + return Err(Error { + position: mem_location.position, + message: "Type mismatch".into(), + }); + } + Ok(()) +}