got trainride working

This commit is contained in:
2021-10-27 22:18:34 +02:00
parent 23d926dbb3
commit 2267eed21c
7 changed files with 248 additions and 32 deletions

View File

@@ -136,6 +136,17 @@ pub enum Expr<'a> {
value: Box<Expression<'a>>,
type_: Type,
},
FuncCall {
position: Position,
name: &'a str,
params: Vec<Expression<'a>>
},
Select {
position: Position,
condition: Box<Expression<'a>>,
if_true: Box<Expression<'a>>,
if_false: Box<Expression<'a>>
}
}
#[derive(Debug, Clone, Copy)]

View File

@@ -104,5 +104,25 @@ fn fold_expr(expr: &mut ast::Expression) {
ref mut condition, ..
} => fold_expr(condition),
ast::Expr::Cast { ref mut value, .. } => fold_expr(value),
ast::Expr::FuncCall {
name,
ref mut params,
..
} => {
for param in params.iter_mut() {
fold_expr(param);
}
use ast::Expr::*;
let params: Vec<_> = params.iter().map(|e| &e.expr).collect();
expr.expr = match (name, params.as_slice()) {
("sqrt", [F32Const(v)]) if *v >= 0.0 => F32Const(v.sqrt()),
_ => return,
};
}
ast::Expr::Select { ref mut condition, ref mut if_true, ref mut if_false, .. } => {
fold_expr(condition);
fold_expr(if_true);
fold_expr(if_false);
}
}
}

View File

@@ -180,6 +180,21 @@ fn collect_locals_expr<'a>(expr: &ast::Expression<'a>, locals: &mut Vec<(&'a str
ast::Expr::LocalTee { value, .. } => collect_locals_expr(value, locals),
ast::Expr::Loop { block, .. } => collect_locals(block, locals),
ast::Expr::Cast { value, .. } => collect_locals_expr(value, locals),
ast::Expr::FuncCall { params, .. } => {
for param in params {
collect_locals_expr(param, locals);
}
}
ast::Expr::Select {
condition,
if_true,
if_false,
..
} => {
collect_locals_expr(condition, locals);
collect_locals_expr(if_true, locals);
collect_locals_expr(if_false, locals);
}
}
}
@@ -336,6 +351,26 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression)
ctx.function.instruction(&inst);
}
}
ast::Expr::FuncCall { name, params, .. } => {
let mut types = vec![];
for param in params {
types.push(param.type_.unwrap());
emit_expression(ctx, param);
}
ctx.function
.instruction(&builtin_function(name, &types).unwrap());
}
ast::Expr::Select {
condition,
if_true,
if_false,
..
} => {
emit_expression(ctx, if_true);
emit_expression(ctx, if_false);
emit_expression(ctx, condition);
ctx.function.instruction(&Instruction::Select);
}
}
}
@@ -355,3 +390,13 @@ fn map_block_type(t: Option<ast::Type>) -> BlockType {
BlockType::Empty
}
}
fn builtin_function(name: &str, params: &[ast::Type]) -> Option<Instruction<'static>> {
use ast::Type::*;
let inst = match (name, params) {
("sqrt", &[F32]) => Instruction::F32Sqrt,
("abs", &[F32]) => Instruction::F32Abs,
_ => return None,
};
Some(inst)
}

View File

@@ -28,7 +28,14 @@ fn main() -> Result<()> {
};
constfold::fold_script(&mut script);
typecheck::tc_script(&mut script).unwrap();
if let Err(err) = typecheck::tc_script(&mut script) {
let line = input[..(input.len() - err.position.0)]
.chars()
.filter(|c| *c == '\n')
.count()
+ 1;
bail!("{} in line {}", err.message, line);
}
let wasm = emit::emit(&script);
wasmparser::validate(&wasm)?;

View File

@@ -167,7 +167,7 @@ fn statement(s: &str) -> IResult<ast::Statement> {
terminated(
pair(
mem_location,
ws(pair(position, preceded(char('='), expression))),
cut(ws(pair(position, preceded(char('='), expression)))),
),
ws(char(';')),
),
@@ -197,7 +197,7 @@ fn local_var(s: &str) -> IResult<ast::LocalVariable> {
name: name,
type_,
value: value.map(|v| v.into()),
defer: defer.is_some()
defer: defer.is_some(),
},
))
})(s)
@@ -225,7 +225,7 @@ fn mem_location(s: &str) -> IResult<ast::MemoryLocation> {
}
fn expression(s: &str) -> IResult<ast::Expr> {
expression_cmp(s)
expression_bit(s)
}
fn expression_atom(s: &str) -> IResult<ast::Expr> {
@@ -242,6 +242,36 @@ fn expression_atom(s: &str) -> IResult<ast::Expr> {
),
map(float, ast::Expr::F32Const),
map(integer, ast::Expr::I32Const),
map(
tuple((
terminated(ws(position), tag("select")),
preceded(ws(char('(')), expression),
preceded(ws(char(',')), expression),
delimited(ws(char(',')), expression, ws(char(')'))),
)),
|(position, condition, if_true, if_false)| ast::Expr::Select {
position,
condition: Box::new(condition.into()),
if_true: Box::new(if_true.into()),
if_false: Box::new(if_false.into()),
},
),
map(
tuple((
ws(position),
identifier,
delimited(
ws(char('(')),
separated_list0(ws(char(',')), expression),
ws(char(')')),
),
)),
|(position, name, params)| ast::Expr::FuncCall {
position,
name,
params: params.into_iter().map(|p| p.into()).collect(),
},
),
map(ws(pair(position, identifier)), |(position, name)| {
ast::Expr::Variable {
position,
@@ -333,33 +363,8 @@ fn expression_sum(s: &str) -> IResult<ast::Expr> {
)(s)
}
fn expression_bit(s: &str) -> IResult<ast::Expr> {
let (s, mut init) = map(expression_sum, Some)(s)?;
fold_many0(
pair(
ws(pair(position, alt((char('&'), char('|'), char('^'))))),
expression_sum,
),
move || init.take().unwrap(),
|left, ((position, op), right)| {
let op = match op {
'&' => ast::BinOp::And,
'|' => ast::BinOp::Or,
'^' => ast::BinOp::Xor,
_ => unreachable!(),
};
ast::Expr::BinOp {
position,
op,
left: Box::new(left.into()),
right: Box::new(right.into()),
}
},
)(s)
}
fn expression_cmp(s: &str) -> IResult<ast::Expr> {
let (s, mut init) = map(expression_bit, Some)(s)?;
let (s, mut init) = map(expression_sum, Some)(s)?;
fold_many0(
pair(
ws(pair(
@@ -373,7 +378,7 @@ fn expression_cmp(s: &str) -> IResult<ast::Expr> {
tag(">"),
)),
)),
expression_bit,
expression_sum,
),
move || init.take().unwrap(),
|left, ((position, op), right)| {
@@ -396,6 +401,31 @@ fn expression_cmp(s: &str) -> IResult<ast::Expr> {
)(s)
}
fn expression_bit(s: &str) -> IResult<ast::Expr> {
let (s, mut init) = map(expression_cmp, Some)(s)?;
fold_many0(
pair(
ws(pair(position, alt((char('&'), char('|'), char('^'))))),
expression_cmp,
),
move || init.take().unwrap(),
|left, ((position, op), right)| {
let op = match op {
'&' => ast::BinOp::And,
'|' => ast::BinOp::Or,
'^' => ast::BinOp::Xor,
_ => unreachable!(),
};
ast::Expr::BinOp {
position,
op,
left: Box::new(left.into()),
right: Box::new(right.into()),
}
},
)(s)
}
fn block_expression(s: &str) -> IResult<ast::Expr> {
loop_(s)
}

View File

@@ -167,7 +167,7 @@ fn tc_expression<'a>(context: &mut Context<'a>, expr: &mut ast::Expression<'a>)
} else {
return Err(Error {
position,
message: "Variable not found".into(),
message: format!("Variable '{}' not found", name),
});
}
}
@@ -228,6 +228,73 @@ fn tc_expression<'a>(context: &mut Context<'a>, expr: &mut ast::Expression<'a>)
}
Some(type_)
}
ast::Expr::FuncCall {
position,
name,
ref mut params,
} => {
if let Some((ptypes, rtype)) = builtin_function_types(name) {
if params.len() != ptypes.len() {
return Err(Error {
position,
message: format!(
"Expected {} parameters but found {}",
ptypes.len(),
params.len()
),
});
}
for (index, (ptype, param)) in ptypes.iter().zip(params.iter_mut()).enumerate() {
tc_expression(context, param)?;
if param.type_.is_none() || param.type_.unwrap() != *ptype {
return Err(Error {
position,
message: format!(
"Param {} is {:?} but should be {:?}",
index + 1,
param.type_,
ptype
),
});
}
}
rtype
} else {
return Err(Error {
position,
message: format!("Unknown function '{}'", name),
});
}
}
ast::Expr::Select {
position,
ref mut condition,
ref mut if_true,
ref mut if_false,
} => {
tc_expression(context, condition)?;
tc_expression(context, if_true)?;
tc_expression(context, if_false)?;
if condition.type_ != Some(ast::Type::I32) {
return Err(Error {
position,
message: "Condition of select has to be of type i32".into(),
});
}
if if_true.type_ != if_false.type_ {
return Err(Error {
position,
message: "Types of select branches differ".into(),
});
}
if if_true.type_.is_none() {
return Err(Error {
position,
message: "Types of select branches cannot be void".into(),
});
}
if_true.type_
}
};
Ok(())
}
@@ -246,3 +313,13 @@ fn tc_mem_location<'a>(
}
Ok(())
}
fn builtin_function_types(name: &str) -> Option<(&'static [ast::Type], Option<ast::Type>)> {
use ast::Type::*;
let types: (&'static [ast::Type], Option<ast::Type>) = match name {
"sqrt" => (&[F32], Some(F32)),
"abs" => (&[F32], Some(F32)),
_ => return None,
};
Some(types)
}