From 2267eed21cc5c3bf91af13e5dd726654f9bec568 Mon Sep 17 00:00:00 2001 From: Dennis Ranke Date: Wed, 27 Oct 2021 22:18:34 +0200 Subject: [PATCH] got trainride working --- src/ast.rs | 11 ++++++ src/constfold.rs | 20 +++++++++++ src/emit.rs | 45 ++++++++++++++++++++++++ src/main.rs | 9 ++++- src/parser.rs | 90 ++++++++++++++++++++++++++++++++---------------- src/typecheck.rs | 79 +++++++++++++++++++++++++++++++++++++++++- trainride.hw | 26 ++++++++++++++ 7 files changed, 248 insertions(+), 32 deletions(-) create mode 100644 trainride.hw diff --git a/src/ast.rs b/src/ast.rs index dc52134..67b91e8 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -136,6 +136,17 @@ pub enum Expr<'a> { value: Box>, type_: Type, }, + FuncCall { + position: Position, + name: &'a str, + params: Vec> + }, + Select { + position: Position, + condition: Box>, + if_true: Box>, + if_false: Box> + } } #[derive(Debug, Clone, Copy)] diff --git a/src/constfold.rs b/src/constfold.rs index fd2c216..ad3d232 100644 --- a/src/constfold.rs +++ b/src/constfold.rs @@ -104,5 +104,25 @@ fn fold_expr(expr: &mut ast::Expression) { ref mut condition, .. } => fold_expr(condition), ast::Expr::Cast { ref mut value, .. } => fold_expr(value), + ast::Expr::FuncCall { + name, + ref mut params, + .. + } => { + for param in params.iter_mut() { + fold_expr(param); + } + use ast::Expr::*; + let params: Vec<_> = params.iter().map(|e| &e.expr).collect(); + expr.expr = match (name, params.as_slice()) { + ("sqrt", [F32Const(v)]) if *v >= 0.0 => F32Const(v.sqrt()), + _ => return, + }; + } + ast::Expr::Select { ref mut condition, ref mut if_true, ref mut if_false, .. } => { + fold_expr(condition); + fold_expr(if_true); + fold_expr(if_false); + } } } diff --git a/src/emit.rs b/src/emit.rs index 9e2974c..ea88a0c 100644 --- a/src/emit.rs +++ b/src/emit.rs @@ -180,6 +180,21 @@ fn collect_locals_expr<'a>(expr: &ast::Expression<'a>, locals: &mut Vec<(&'a str ast::Expr::LocalTee { value, .. } => collect_locals_expr(value, locals), ast::Expr::Loop { block, .. } => collect_locals(block, locals), ast::Expr::Cast { value, .. } => collect_locals_expr(value, locals), + ast::Expr::FuncCall { params, .. } => { + for param in params { + collect_locals_expr(param, locals); + } + } + ast::Expr::Select { + condition, + if_true, + if_false, + .. + } => { + collect_locals_expr(condition, locals); + collect_locals_expr(if_true, locals); + collect_locals_expr(if_false, locals); + } } } @@ -336,6 +351,26 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression) ctx.function.instruction(&inst); } } + ast::Expr::FuncCall { name, params, .. } => { + let mut types = vec![]; + for param in params { + types.push(param.type_.unwrap()); + emit_expression(ctx, param); + } + ctx.function + .instruction(&builtin_function(name, &types).unwrap()); + } + ast::Expr::Select { + condition, + if_true, + if_false, + .. + } => { + emit_expression(ctx, if_true); + emit_expression(ctx, if_false); + emit_expression(ctx, condition); + ctx.function.instruction(&Instruction::Select); + } } } @@ -355,3 +390,13 @@ fn map_block_type(t: Option) -> BlockType { BlockType::Empty } } + +fn builtin_function(name: &str, params: &[ast::Type]) -> Option> { + use ast::Type::*; + let inst = match (name, params) { + ("sqrt", &[F32]) => Instruction::F32Sqrt, + ("abs", &[F32]) => Instruction::F32Abs, + _ => return None, + }; + Some(inst) +} diff --git a/src/main.rs b/src/main.rs index 9f5bf75..7d232c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,7 +28,14 @@ fn main() -> Result<()> { }; constfold::fold_script(&mut script); - typecheck::tc_script(&mut script).unwrap(); + if let Err(err) = typecheck::tc_script(&mut script) { + let line = input[..(input.len() - err.position.0)] + .chars() + .filter(|c| *c == '\n') + .count() + + 1; + bail!("{} in line {}", err.message, line); + } let wasm = emit::emit(&script); wasmparser::validate(&wasm)?; diff --git a/src/parser.rs b/src/parser.rs index 14b28fe..9311171 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -167,7 +167,7 @@ fn statement(s: &str) -> IResult { terminated( pair( mem_location, - ws(pair(position, preceded(char('='), expression))), + cut(ws(pair(position, preceded(char('='), expression)))), ), ws(char(';')), ), @@ -197,7 +197,7 @@ fn local_var(s: &str) -> IResult { name: name, type_, value: value.map(|v| v.into()), - defer: defer.is_some() + defer: defer.is_some(), }, )) })(s) @@ -225,7 +225,7 @@ fn mem_location(s: &str) -> IResult { } fn expression(s: &str) -> IResult { - expression_cmp(s) + expression_bit(s) } fn expression_atom(s: &str) -> IResult { @@ -242,6 +242,36 @@ fn expression_atom(s: &str) -> IResult { ), map(float, ast::Expr::F32Const), map(integer, ast::Expr::I32Const), + map( + tuple(( + terminated(ws(position), tag("select")), + preceded(ws(char('(')), expression), + preceded(ws(char(',')), expression), + delimited(ws(char(',')), expression, ws(char(')'))), + )), + |(position, condition, if_true, if_false)| ast::Expr::Select { + position, + condition: Box::new(condition.into()), + if_true: Box::new(if_true.into()), + if_false: Box::new(if_false.into()), + }, + ), + map( + tuple(( + ws(position), + identifier, + delimited( + ws(char('(')), + separated_list0(ws(char(',')), expression), + ws(char(')')), + ), + )), + |(position, name, params)| ast::Expr::FuncCall { + position, + name, + params: params.into_iter().map(|p| p.into()).collect(), + }, + ), map(ws(pair(position, identifier)), |(position, name)| { ast::Expr::Variable { position, @@ -333,33 +363,8 @@ fn expression_sum(s: &str) -> IResult { )(s) } -fn expression_bit(s: &str) -> IResult { - let (s, mut init) = map(expression_sum, Some)(s)?; - fold_many0( - pair( - ws(pair(position, alt((char('&'), char('|'), char('^'))))), - expression_sum, - ), - move || init.take().unwrap(), - |left, ((position, op), right)| { - let op = match op { - '&' => ast::BinOp::And, - '|' => ast::BinOp::Or, - '^' => ast::BinOp::Xor, - _ => unreachable!(), - }; - ast::Expr::BinOp { - position, - op, - left: Box::new(left.into()), - right: Box::new(right.into()), - } - }, - )(s) -} - fn expression_cmp(s: &str) -> IResult { - let (s, mut init) = map(expression_bit, Some)(s)?; + let (s, mut init) = map(expression_sum, Some)(s)?; fold_many0( pair( ws(pair( @@ -373,7 +378,7 @@ fn expression_cmp(s: &str) -> IResult { tag(">"), )), )), - expression_bit, + expression_sum, ), move || init.take().unwrap(), |left, ((position, op), right)| { @@ -396,6 +401,31 @@ fn expression_cmp(s: &str) -> IResult { )(s) } +fn expression_bit(s: &str) -> IResult { + let (s, mut init) = map(expression_cmp, Some)(s)?; + fold_many0( + pair( + ws(pair(position, alt((char('&'), char('|'), char('^'))))), + expression_cmp, + ), + move || init.take().unwrap(), + |left, ((position, op), right)| { + let op = match op { + '&' => ast::BinOp::And, + '|' => ast::BinOp::Or, + '^' => ast::BinOp::Xor, + _ => unreachable!(), + }; + ast::Expr::BinOp { + position, + op, + left: Box::new(left.into()), + right: Box::new(right.into()), + } + }, + )(s) +} + fn block_expression(s: &str) -> IResult { loop_(s) } diff --git a/src/typecheck.rs b/src/typecheck.rs index a239082..7111e4b 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -167,7 +167,7 @@ fn tc_expression<'a>(context: &mut Context<'a>, expr: &mut ast::Expression<'a>) } else { return Err(Error { position, - message: "Variable not found".into(), + message: format!("Variable '{}' not found", name), }); } } @@ -228,6 +228,73 @@ fn tc_expression<'a>(context: &mut Context<'a>, expr: &mut ast::Expression<'a>) } Some(type_) } + ast::Expr::FuncCall { + position, + name, + ref mut params, + } => { + if let Some((ptypes, rtype)) = builtin_function_types(name) { + if params.len() != ptypes.len() { + return Err(Error { + position, + message: format!( + "Expected {} parameters but found {}", + ptypes.len(), + params.len() + ), + }); + } + for (index, (ptype, param)) in ptypes.iter().zip(params.iter_mut()).enumerate() { + tc_expression(context, param)?; + if param.type_.is_none() || param.type_.unwrap() != *ptype { + return Err(Error { + position, + message: format!( + "Param {} is {:?} but should be {:?}", + index + 1, + param.type_, + ptype + ), + }); + } + } + rtype + } else { + return Err(Error { + position, + message: format!("Unknown function '{}'", name), + }); + } + } + ast::Expr::Select { + position, + ref mut condition, + ref mut if_true, + ref mut if_false, + } => { + tc_expression(context, condition)?; + tc_expression(context, if_true)?; + tc_expression(context, if_false)?; + if condition.type_ != Some(ast::Type::I32) { + return Err(Error { + position, + message: "Condition of select has to be of type i32".into(), + }); + } + if if_true.type_ != if_false.type_ { + return Err(Error { + position, + message: "Types of select branches differ".into(), + }); + } + if if_true.type_.is_none() { + return Err(Error { + position, + message: "Types of select branches cannot be void".into(), + }); + } + if_true.type_ + } }; Ok(()) } @@ -246,3 +313,13 @@ fn tc_mem_location<'a>( } Ok(()) } + +fn builtin_function_types(name: &str) -> Option<(&'static [ast::Type], Option)> { + use ast::Type::*; + let types: (&'static [ast::Type], Option) = match name { + "sqrt" => (&[F32], Some(F32)), + "abs" => (&[F32], Some(F32)), + _ => return None, + }; + Some(types) +} diff --git a/trainride.hw b/trainride.hw new file mode 100644 index 0000000..3e7cb7b --- /dev/null +++ b/trainride.hw @@ -0,0 +1,26 @@ +import "uw8.ram" memory(2); +import "uw8.time" global mut time: i32; + +export fn tic() { + let i: i32; + let defer t = time as f32 / 1000 as f32; + loop pixels { + let defer x = (i % 320 - 160) as f32; + let defer y = (i / 320) as f32 - 128.5; + let defer z = t + 20 as f32 / sqrt(x*x + y*y); + let defer z_int = z as i32; + let defer q = select(z_int % 9 >= 6, z, (z_int - z_int % 9 + 6) as f32); + let defer w = 9 as f32 / y + t; + let defer s = q - t; + let defer m = x * s / 50 as f32; + + i?120 = select(y > 0 as f32 & w < q, + select(abs(x * (w - t)) < 9 as f32, 15, 7) - w as i32 % 2, + select(y * s > -99 as f32 / (m * m + 1 as f32), + select(q == z, z_int % 2, 3), + 12 + (y / 23 as f32) as i32 + ) + ) * 16; + branch_if (i := i + 1) < 320*256: pixels; + } +}