From d0ee844d8c9e8703c388f749b73d3214d75306db Mon Sep 17 00:00:00 2001 From: Dennis Ranke Date: Mon, 1 Nov 2021 09:21:36 +0100 Subject: [PATCH] implemented if --- Cargo.lock | 30 --------- Cargo.toml | 1 - src/ast.rs | 26 ++++---- src/constfold.rs | 31 ++++++---- src/emit.rs | 147 ++++++++++++++++++++++++++++++-------------- src/parser.rs | 156 +++++++++++++++++++++++++++++++++++++---------- src/typecheck.rs | 56 ++++++++++++----- 7 files changed, 297 insertions(+), 150 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eaf9ea6..f786bd9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -87,7 +87,6 @@ dependencies = [ "anyhow", "ariadne", "chumsky", - "nom", "wasm-encoder", "wasmparser", ] @@ -110,29 +109,6 @@ version = "0.2.105" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "869d572136620d55835903746bcb5cdc54cb2851fd0aeec53220b4bb65ef3013" -[[package]] -name = "memchr" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" - -[[package]] -name = "minimal-lexical" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c64630dcdd71f1a64c435f54885086a0de5d6a12d104d69b165fb7d5286d677" - -[[package]] -name = "nom" -version = "7.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffd9d26838a953b4af82cbeb9f1592c6798916983959be223a7124e992742c1" -dependencies = [ - "memchr", - "minimal-lexical", - "version_check", -] - [[package]] name = "proc-macro-hack" version = "0.5.19" @@ -148,12 +124,6 @@ dependencies = [ "crunchy", ] -[[package]] -name = "version_check" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" - [[package]] name = "wasi" version = "0.10.2+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 8d70333..bb2a6af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -nom = "7" wasmparser = "0.81" wasm-encoder = "0.8" anyhow = "1" diff --git a/src/ast.rs b/src/ast.rs index a11cb99..3865711 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -46,19 +46,7 @@ pub struct Function { pub name: String, pub params: Vec<(String, Type)>, pub type_: Option, - pub body: Block, -} - -#[derive(Debug)] -pub struct Block { - pub statements: Vec, - pub final_expression: Option>, -} - -impl Block { - pub fn type_(&self) -> Option { - self.final_expression.as_ref().and_then(|e| e.type_) - } + pub body: Expression, } #[derive(Debug)] @@ -78,6 +66,10 @@ pub struct Expression { #[derive(Debug)] pub enum Expr { + Block { + statements: Vec, + final_expression: Option>, + }, I32Const(i32), F32Const(f32), Variable(String), @@ -91,9 +83,10 @@ pub enum Expr { mem_location: MemoryLocation, value: Box, }, + Peek(MemoryLocation), Loop { label: String, - block: Box, + block: Box, }, BranchIf { condition: Box, @@ -125,6 +118,11 @@ pub enum Expr { if_true: Box, if_false: Box, }, + If { + condition: Box, + if_true: Box, + if_false: Option> + }, Error, } diff --git a/src/constfold.rs b/src/constfold.rs index d67c6ba..34f260b 100644 --- a/src/constfold.rs +++ b/src/constfold.rs @@ -2,16 +2,7 @@ use crate::ast; pub fn fold_script(script: &mut ast::Script) { for func in &mut script.functions { - fold_block(&mut func.body); - } -} - -fn fold_block(block: &mut ast::Block) { - for stmt in &mut block.statements { - fold_expr(stmt); - } - if let Some(ref mut expr) = block.final_expression { - fold_expr(expr); + fold_expr(&mut func.body); } } @@ -23,6 +14,14 @@ fn fold_mem_location(mem_location: &mut ast::MemoryLocation) { fn fold_expr(expr: &mut ast::Expression) { use ast::BinOp::*; match expr.expr { + ast::Expr::Block { ref mut statements, ref mut final_expression} => { + for stmt in statements { + fold_expr(stmt); + } + if let Some(ref mut expr) = final_expression { + fold_expr(expr); + } + } ast::Expr::Let { ref mut value, .. } => { if let Some(ref mut expr) = value { fold_expr(expr); @@ -36,6 +35,7 @@ fn fold_expr(expr: &mut ast::Expression) { fold_mem_location(mem_location); fold_expr(value); } + ast::Expr::Peek(ref mut mem_location) => fold_mem_location(mem_location), ast::Expr::UnaryOp { op, ref mut value } => { fold_expr(value); let result = match (op, &value.expr) { @@ -107,7 +107,7 @@ fn fold_expr(expr: &mut ast::Expression) { } ast::Expr::I32Const(_) | ast::Expr::F32Const(_) | ast::Expr::Variable { .. } => (), ast::Expr::LocalTee { ref mut value, .. } => fold_expr(value), - ast::Expr::Loop { ref mut block, .. } => fold_block(block), + ast::Expr::Loop { ref mut block, .. } => fold_expr(block), ast::Expr::BranchIf { ref mut condition, .. } => fold_expr(condition), @@ -137,6 +137,15 @@ fn fold_expr(expr: &mut ast::Expression) { fold_expr(if_true); fold_expr(if_false); } + ast::Expr::If { + ref mut condition, ref mut if_true, ref mut if_false + } => { + fold_expr(condition); + fold_expr(if_true); + if let Some(ref mut if_false) = if_false { + fold_expr(if_false); + } + } ast::Expr::Error => unreachable!() } } diff --git a/src/emit.rs b/src/emit.rs index e892ba6..7a4c807 100644 --- a/src/emit.rs +++ b/src/emit.rs @@ -116,7 +116,7 @@ struct FunctionContext<'a> { fn emit_function(func: &ast::Function, globals: &HashMap<&str, u32>) -> Function { let mut locals = Vec::new(); - collect_locals(&func.body, &mut locals); + collect_locals_expr(&func.body, &mut locals); locals.sort_by_key(|(_, t)| *t); let mut function = Function::new_with_locals_types(locals.iter().map(|(_, t)| map_type(*t))); @@ -135,8 +135,8 @@ fn emit_function(func: &ast::Function, globals: &HashMap<&str, u32>) -> Function deferred_inits: HashMap::new(), }; - emit_block(&mut context, &func.body); - if func.type_.is_none() && func.body.type_().is_some() { + emit_expression(&mut context, &func.body); + if func.type_.is_none() && func.body.type_.is_some() { function.instruction(&Instruction::Drop); } function.instruction(&Instruction::End); @@ -144,23 +144,28 @@ fn emit_function(func: &ast::Function, globals: &HashMap<&str, u32>) -> Function function } -fn collect_locals<'a>(block: &ast::Block, locals: &mut Vec<(String, ast::Type)>) { - for stmt in &block.statements { - collect_locals_expr(stmt, locals); - } - if let Some(ref expr) = block.final_expression { - collect_locals_expr(expr, locals); - } -} - fn collect_locals_expr<'a>(expr: &ast::Expression, locals: &mut Vec<(String, ast::Type)>) { match &expr.expr { - ast::Expr::Let {name, type_, value, ..} => { + ast::Expr::Block { + statements, + final_expression, + } => { + for stmt in statements { + collect_locals_expr(stmt, locals); + } + if let Some(ref expr) = final_expression { + collect_locals_expr(expr, locals); + } + } + ast::Expr::Let { + name, type_, value, .. + } => { locals.push((name.clone(), type_.unwrap())); if let Some(ref value) = value { collect_locals_expr(value, locals); } } + ast::Expr::Peek(mem_location) => collect_locals_expr(&mem_location.left, locals), ast::Expr::Poke { mem_location, value, @@ -177,7 +182,7 @@ fn collect_locals_expr<'a>(expr: &ast::Expression, locals: &mut Vec<(String, ast } ast::Expr::BranchIf { condition, .. } => collect_locals_expr(condition, locals), ast::Expr::LocalTee { value, .. } => collect_locals_expr(value, locals), - ast::Expr::Loop { block, .. } => collect_locals(block, locals), + ast::Expr::Loop { block, .. } => collect_locals_expr(block, locals), ast::Expr::Cast { value, .. } => collect_locals_expr(value, locals), ast::Expr::FuncCall { params, .. } => { for param in params { @@ -194,25 +199,60 @@ fn collect_locals_expr<'a>(expr: &ast::Expression, locals: &mut Vec<(String, ast collect_locals_expr(if_true, locals); collect_locals_expr(if_false, locals); } - ast::Expr::Error => unreachable!() + ast::Expr::If { + condition, + if_true, + if_false, + } => { + collect_locals_expr(condition, locals); + collect_locals_expr(if_true, locals); + if let Some(if_false) = if_false { + collect_locals_expr(if_false, locals); + } + } + ast::Expr::Error => unreachable!(), } } -fn emit_block<'a>(ctx: &mut FunctionContext<'a>, block: &'a ast::Block) { - for stmt in &block.statements { - emit_expression(ctx, stmt); - if stmt.type_.is_some() { - ctx.function.instruction(&Instruction::Drop); - } - } - if let Some(ref expr) = block.final_expression { - emit_expression(ctx, expr); +fn mem_arg_for_location(mem_location: &ast::MemoryLocation) -> MemArg { + let offset = if let ast::Expr::I32Const(v) = mem_location.right.expr { + v as u32 as u64 + } else { + unreachable!() + }; + match mem_location.size { + ast::MemSize::Byte => MemArg { + align: 0, + memory_index: 0, + offset, + }, + ast::MemSize::Word => MemArg { + align: 2, + memory_index: 0, + offset, + }, } } fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression) { match &expr.expr { - ast::Expr::Let { value, name, defer, ..} => { + ast::Expr::Block { + statements, + final_expression, + } => { + for stmt in statements { + emit_expression(ctx, stmt); + if stmt.type_.is_some() { + ctx.function.instruction(&Instruction::Drop); + } + } + if let Some(ref expr) = final_expression { + emit_expression(ctx, expr); + } + } + ast::Expr::Let { + value, name, defer, .. + } => { if let Some(ref val) = value { if *defer { ctx.deferred_inits.insert(name, val); @@ -223,34 +263,29 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression) } } } + ast::Expr::Peek(mem_location) => { + emit_expression(ctx, &mem_location.left); + let mem_arg = mem_arg_for_location(mem_location); + ctx.function.instruction(&match mem_location.size { + ast::MemSize::Byte => Instruction::I32Load8_U(mem_arg), + ast::MemSize::Word => Instruction::I32Load(mem_arg), + }); + } ast::Expr::Poke { mem_location, value, - .. } => { emit_expression(ctx, &mem_location.left); emit_expression(ctx, value); - let offset = if let ast::Expr::I32Const(v) = mem_location.right.expr { - v as u32 as u64 - } else { - unreachable!() - }; + let mem_arg = mem_arg_for_location(mem_location); ctx.function.instruction(&match mem_location.size { - ast::MemSize::Byte => Instruction::I32Store8(MemArg { - align: 0, - memory_index: 0, - offset, - }), - ast::MemSize::Word => Instruction::I32Store(MemArg { - align: 2, - memory_index: 0, - offset, - }), + ast::MemSize::Byte => Instruction::I32Store8(mem_arg), + ast::MemSize::Word => Instruction::I32Store(mem_arg), }); } ast::Expr::UnaryOp { op, value } => { - use ast::UnaryOp::*; use ast::Type::*; + use ast::UnaryOp::*; match (value.type_.unwrap(), op) { (I32, Negate) => { // TODO: try to improve this uglyness @@ -258,7 +293,7 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression) emit_expression(ctx, value); ctx.function.instruction(&Instruction::I32Sub); } - _ => unreachable!() + _ => unreachable!(), }; } ast::Expr::BinOp { @@ -328,8 +363,8 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression) ast::Expr::Loop { label, block, .. } => { ctx.labels.push(label.to_string()); ctx.function - .instruction(&Instruction::Loop(map_block_type(block.type_()))); - emit_block(ctx, block); + .instruction(&Instruction::Loop(map_block_type(block.type_))); + emit_expression(ctx, block); ctx.labels.pop(); ctx.function.instruction(&Instruction::End); } @@ -380,7 +415,27 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression) emit_expression(ctx, condition); ctx.function.instruction(&Instruction::Select); } - ast::Expr::Error => unreachable!() + ast::Expr::If { + condition, + if_true, + if_false, + } => { + emit_expression(ctx, condition); + ctx.function + .instruction(&Instruction::If(map_block_type(expr.type_))); + emit_expression(ctx, if_true); + if if_true.type_.is_some() && if_true.type_ != expr.type_ { + ctx.function.instruction(&Instruction::Drop); + } + if let Some(if_false) = if_false { + ctx.function.instruction(&Instruction::Else); + emit_expression(ctx, if_false); + if if_false.type_.is_some() && if_false.type_ != expr.type_ { + ctx.function.instruction(&Instruction::Drop); + } + } + } + ast::Expr::Error => unreachable!(), } } diff --git a/src/parser.rs b/src/parser.rs index 8f20dfa..0dcef0a 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,8 +1,8 @@ +use crate::ast; +use crate::Span; use ariadne::{Color, Fmt, Label, Report, ReportKind, Source}; use chumsky::{prelude::*, stream::Stream}; use std::fmt; -use crate::ast; -use crate::Span; #[derive(Clone, Debug, PartialEq, Eq, Hash)] enum Token { @@ -18,6 +18,8 @@ enum Token { Defer, As, Select, + If, + Else, Ident(String), Str(String), Int(i32), @@ -41,6 +43,8 @@ impl fmt::Display for Token { Token::Defer => write!(f, "defer"), Token::As => write!(f, "as"), Token::Select => write!(f, "select"), + Token::If => write!(f, "if"), + Token::Else => write!(f, "else"), Token::Ident(s) => write!(f, "{}", s), Token::Str(s) => write!(f, "{:?}", s), Token::Int(v) => write!(f, "{}", v), @@ -191,6 +195,8 @@ fn lexer() -> impl Parser, Error = Simple> { "defer" => Token::Defer, "as" => Token::As, "select" => Token::Select, + "if" => Token::If, + "Else" => Token::Else, _ => Token::Ident(ident), }); @@ -229,7 +235,7 @@ fn map_token( }) } -fn block_parser() -> impl Parser> + Clone { +fn block_parser() -> impl Parser> + Clone { recursive(|block| { let mut block_expression = None; let expression = recursive(|expression| { @@ -257,7 +263,8 @@ fn block_parser() -> impl Parser> + Clo .map(|(name, expr)| ast::Expr::LocalTee { name, value: Box::new(expr), - }).boxed(); + }) + .boxed(); let loop_expr = just(Token::Loop) .ignore_then(ident) @@ -271,7 +278,29 @@ fn block_parser() -> impl Parser> + Clo block: Box::new(block), }); - let block_expr = loop_expr.boxed(); + let if_expr = just(Token::If) + .ignore_then(expression.clone()) + .then( + block + .clone() + .delimited_by(Token::Ctrl('{'), Token::Ctrl('}')), + ) + .then( + just(Token::Else) + .ignore_then( + block + .clone() + .delimited_by(Token::Ctrl('{'), Token::Ctrl('}')), + ) + .or_not(), + ) + .map(|((condition, if_true), if_false)| ast::Expr::If { + condition: Box::new(condition), + if_true: Box::new(if_true), + if_false: if_false.map(Box::new), + }); + + let block_expr = loop_expr.or(if_expr).boxed(); block_expression = Some(block_expr.clone()); @@ -282,7 +311,8 @@ fn block_parser() -> impl Parser> + Clo .map(|(condition, label)| ast::Expr::BranchIf { condition: Box::new(condition), label, - }).boxed(); + }) + .boxed(); let let_ = just(Token::Let) .ignore_then(just(Token::Defer).or_not()) @@ -298,7 +328,8 @@ fn block_parser() -> impl Parser> + Clo type_, value: value.map(Box::new), defer: defer.is_some(), - }).boxed(); + }) + .boxed(); let tee = ident .clone() @@ -307,7 +338,8 @@ fn block_parser() -> impl Parser> + Clo .map(|(name, value)| ast::Expr::LocalTee { name, value: Box::new(value), - }).boxed(); + }) + .boxed(); let select = just(Token::Select) .ignore_then( @@ -323,7 +355,8 @@ fn block_parser() -> impl Parser> + Clo condition: Box::new(condition), if_true: Box::new(if_true), if_false: Box::new(if_false), - }).boxed(); + }) + .boxed(); let function_call = ident .clone() @@ -333,7 +366,8 @@ fn block_parser() -> impl Parser> + Clo .separated_by(just(Token::Ctrl(','))) .delimited_by(Token::Ctrl('('), Token::Ctrl(')')), ) - .map(|(name, params)| ast::Expr::FuncCall { name, params }).boxed(); + .map(|(name, params)| ast::Expr::FuncCall { name, params }) + .boxed(); let atom = val .or(tee) @@ -353,8 +387,9 @@ fn block_parser() -> impl Parser> + Clo Token::Ctrl(')'), [(Token::Ctrl('{'), Token::Ctrl('}'))], |span| ast::Expr::Error.with_span(span), - )).boxed(); - + )) + .boxed(); + let unary_op = just(Token::Op("-".to_string())) .to(ast::UnaryOp::Negate) .map_with_span(|op, span| (op, span)) @@ -369,7 +404,8 @@ fn block_parser() -> impl Parser> + Clo } .with_span(span) }) - }).boxed(); + }) + .boxed(); let op_cast = unary_op .clone() @@ -385,22 +421,31 @@ fn block_parser() -> impl Parser> + Clo type_, } .with_span(span) - }).boxed(); + }) + .boxed(); let mem_size = just(Token::Ctrl('?')) .to(ast::MemSize::Byte) .or(just(Token::Ctrl('!')).to(ast::MemSize::Word)); - let memory_op = op_cast - .clone() - .then( - mem_size - .then(op_cast.clone()) - .then_ignore(just(Token::Op("=".to_string()))) - .then(expression.clone()) - .repeated(), - ) - .foldl(|left, ((size, right), value)| { + let mem_op = mem_size.then(op_cast.clone()); + + fn make_memory_op( + left: ast::Expression, + peek_ops: Vec<(ast::MemSize, ast::Expression)>, + poke_op: Option<((ast::MemSize, ast::Expression), ast::Expression)>, + ) -> ast::Expression { + let left = peek_ops.into_iter().fold(left, |left, (size, right)| { + let span = left.span.start..right.span.end; + ast::Expr::Peek(ast::MemoryLocation { + span: span.clone(), + left: Box::new(left), + size, + right: Box::new(right), + }) + .with_span(span) + }); + if let Some(((size, right), value)) = poke_op { let span = left.span.start..value.span.end; ast::Expr::Poke { mem_location: ast::MemoryLocation { @@ -412,7 +457,47 @@ fn block_parser() -> impl Parser> + Clo value: Box::new(value), } .with_span(span) - }).boxed(); + } else { + left + } + } + + let short_memory_op = mem_op + .clone() + .then( + just(Token::Op("=".to_string())) + .ignore_then(expression.clone()) + .or_not(), + ) + .map(|((size, left), value)| { + let right = ast::Expr::I32Const(0).with_span(left.span.clone()); + if let Some(value) = value { + make_memory_op(left, vec![], Some(((size, right), value))) + } else { + make_memory_op(left, vec![(size, right)], None) + } + }); + + let memory_op = op_cast + .clone() + .or(short_memory_op.clone()) + .then(mem_op.clone().repeated().at_least(1)) + .then( + just(Token::Op("=".to_string())) + .ignore_then(expression.clone()) + .or_not(), + ) + .map(|((left, mut peek_ops), poke_op)| { + if let Some(value) = poke_op { + let poke_op = Some((peek_ops.pop().unwrap(), value)); + make_memory_op(left, peek_ops, poke_op) + } else { + make_memory_op(left, peek_ops, None) + } + }) + .boxed() + .or(op_cast.clone()) + .or(short_memory_op.clone()); let op_product = memory_op .clone() @@ -432,7 +517,8 @@ fn block_parser() -> impl Parser> + Clo right: Box::new(right), } .with_span(span) - }).boxed(); + }) + .boxed(); let op_sum = op_product .clone() @@ -451,7 +537,8 @@ fn block_parser() -> impl Parser> + Clo right: Box::new(right), } .with_span(span) - }).boxed(); + }) + .boxed(); let op_cmp = op_sum .clone() @@ -474,7 +561,8 @@ fn block_parser() -> impl Parser> + Clo right: Box::new(right), } .with_span(span) - }).boxed(); + }) + .boxed(); let op_bit = op_cmp .clone() @@ -494,7 +582,8 @@ fn block_parser() -> impl Parser> + Clo right: Box::new(right), } .with_span(span) - }).boxed(); + }) + .boxed(); op_bit }); @@ -507,9 +596,12 @@ fn block_parser() -> impl Parser> + Clo .or(block_expression.map_with_span(|expr, span| expr.with_span(span))) .repeated() .then(expression.clone().or_not()) - .map(|(statements, final_expression)| ast::Block { - statements, - final_expression: final_expression.map(|e| Box::new(e)), + .map_with_span(|(statements, final_expression), span| { + ast::Expr::Block { + statements, + final_expression: final_expression.map(|e| Box::new(e)), + } + .with_span(span) }) }) } diff --git a/src/typecheck.rs b/src/typecheck.rs index e3ec209..6d9fb26 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -85,7 +85,7 @@ pub fn tc_script(script: &mut ast::Script, source: &str) -> Result<()> { } } - tc_block(&mut context, &mut f.body)?; + tc_expression(&mut context, &mut f.body)?; } result @@ -98,19 +98,6 @@ struct Context<'a> { block_stack: Vec, } -fn tc_block(context: &mut Context, block: &mut ast::Block) -> Result<()> { - let mut result = Ok(()); - for stmt in &mut block.statements { - if tc_expression(context, stmt).is_err() { - result = Err(()); - } - } - if let Some(ref mut expr) = block.final_expression { - tc_expression(context, expr)?; - } - result -} - fn report_duplicate_definition( msg: &str, span: &Span, @@ -195,6 +182,20 @@ fn unknown_variable(span: &Span, source: &str) -> Result<()> { fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<()> { expr.type_ = match expr.expr { + ast::Expr::Block { + ref mut statements, + ref mut final_expression + } => { + for stmt in statements { + tc_expression(context, stmt)?; + } + if let Some(final_expression) = final_expression { + tc_expression(context, final_expression)?; + final_expression.type_ + } else { + None + } + } ast::Expr::Let { ref mut value, ref mut type_, @@ -254,6 +255,10 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<() } None } + ast::Expr::Peek(ref mut mem_location) => { + tc_mem_location(context, mem_location)?; + Some(I32) + } ast::Expr::Poke { ref mut mem_location, ref mut value, @@ -346,9 +351,9 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<() ref mut block, } => { context.block_stack.push(label.clone()); - tc_block(context, block)?; + tc_expression(context, block)?; context.block_stack.pop(); - block.final_expression.as_ref().and_then(|e| e.type_) + block.type_ } ast::Expr::BranchIf { ref mut condition, @@ -474,6 +479,25 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<() } if_true.type_ } + ast::Expr::If { + ref mut condition, + ref mut if_true, + ref mut if_false + } => { + tc_expression(context, condition)?; + tc_expression(context, if_true)?; + if let Some(ref mut if_false) = if_false { + tc_expression(context, if_false)?; + if if_true.type_ != if_false.type_ { + // TODO: report type mismatch? + None + } else { + if_true.type_ + } + } else { + None + } + } ast::Expr::Error => unreachable!(), }; Ok(())