use crate::ast; use nom::{ branch::alt, bytes::complete::tag, character::complete::{alpha1, alphanumeric1, char, digit1, multispace0}, combinator::{self, cut, map, map_res, not, opt, peek, recognize, success, value}, error::VerboseError, multi::{fold_many0, many0, separated_list0}, sequence::{delimited, pair, preceded, separated_pair, terminated}, Finish, }; type IResult<'a, O> = nom::IResult<&'a str, O, VerboseError<&'a str>>; pub fn parse(s: &str) -> Result> { let (_, script) = combinator::all_consuming(terminated(script, multispace0))(s).finish()?; Ok(script) } fn script(s: &str) -> IResult { let (s, items) = many0(top_level_item)(s)?; Ok((s, ast::Script { items })) } fn top_level_item(s: &str) -> IResult { alt(( map(function, ast::TopLevelItem::Function), map(global_var, ast::TopLevelItem::GlobalVar), ))(s) } fn global_var(s: &str) -> IResult { let (s, vis) = visibility(s)?; let (s, position) = ws(position)(s)?; let (s, name) = identifier(s)?; let (s, type_) = preceded(ws(char(':')), type_)(s)?; let (s, _) = ws(char(';'))(s)?; Ok(( s, ast::GlobalVar { position, visibility: vis, name: name, type_, }, )) } fn function(s: &str) -> IResult { let (s, vis) = visibility(s)?; let (s, _) = ws(tag("fn"))(s)?; cut(move |s| { let (s, position) = ws(position)(s)?; let (s, name) = identifier(s)?; let (s, params) = delimited( ws(char('(')), separated_list0( ws(char(',')), pair(map(identifier, |i| i), preceded(ws(tag(":")), type_)), ), ws(char(')')), )(s)?; let (s, type_) = opt(preceded(ws(tag("->")), type_))(s)?; let (s, body) = block(s)?; Ok(( s, ast::Function { position, visibility: vis, name: name, params, type_, body, }, )) })(s) } fn block(s: &str) -> IResult { let (s, (statements, final_expression)) = delimited( ws(char('{')), pair(many0(statement), opt(expression)), ws(char('}')), )(s)?; Ok(( s, ast::Block { statements, final_expression, }, )) } fn statement(s: &str) -> IResult { alt(( map(local_var, ast::Statement::LocalVariable), map( terminated(expression, ws(char(';'))), ast::Statement::Expression, ), map( terminated(block_expression, not(peek(ws(char('}'))))), ast::Statement::Expression, ), map( terminated( pair(mem_location, preceded(ws(char('=')), expression)), ws(char(';')), ), |(mem_location, value)| ast::Statement::Poke { mem_location, value, }, ), ))(s) } fn local_var(s: &str) -> IResult { let (s, _) = ws(tag("let"))(s)?; let (s, position) = ws(position)(s)?; let (s, name) = identifier(s)?; let (s, type_) = opt(preceded(ws(char(':')), type_))(s)?; let (s, value) = opt(preceded(ws(char('=')), expression))(s)?; let (s, _) = ws(char(';'))(s)?; Ok(( s, ast::LocalVariable { position, name: name, type_, value, }, )) } fn mem_location(s: &str) -> IResult { let (s, position) = ws(position)(s)?; let (s, left) = expression(s)?; let (s, size) = map(ws(alt((char('?'), char('!')))), |op| match op { '?' => ast::MemSize::Byte, '!' => ast::MemSize::Word, _ => unreachable!(), })(s)?; let (s, right) = expression(s)?; Ok(( s, ast::MemoryLocation { position, size, left, right, }, )) } fn expression(s: &str) -> IResult { expression_cmp(s) } fn expression_atom(s: &str) -> IResult { alt(( branch_if, block_expression, map( separated_pair(pair(ws(position), identifier), ws(tag(":=")), expression), |((position, name), value)| ast::Expression::LocalTee { position, name: name, value: Box::new(value), }, ), map(integer, |v| ast::Expression::I32Const(v)), map(ws(pair(position, identifier)), |(position, name)| { ast::Expression::Variable { position, name: name, } }), delimited(ws(char('(')), cut(expression), ws(char(')'))), ))(s) } fn branch_if(s: &str) -> IResult { let (s, position) = ws(position)(s)?; let (s, _) = tag("branch_if")(s)?; cut(move |s| { let (s, condition) = expression(s)?; let (s, _) = ws(char(':'))(s)?; let (s, label) = identifier(s)?; Ok(( s, ast::Expression::BranchIf { position, condition: Box::new(condition), label: label, }, )) })(s) } fn expression_product(s: &str) -> IResult { let (s, mut init) = map(expression_atom, Some)(s)?; fold_many0( pair( ws(pair(position, alt((char('*'), char('/'), char('%'))))), expression_atom, ), move || init.take().unwrap(), |left, ((position, op), right)| { let op = match op { '*' => ast::BinOp::Mul, '/' => ast::BinOp::Div, '%' => ast::BinOp::Rem, _ => unreachable!(), }; ast::Expression::BinOp { position, op, left: Box::new(left), right: Box::new(right), } }, )(s) } fn expression_sum(s: &str) -> IResult { let (s, mut init) = map(expression_product, Some)(s)?; fold_many0( pair( ws(pair(position, alt((char('+'), char('-'))))), expression_product, ), move || init.take().unwrap(), |left, ((position, op), right)| { let op = if op == '+' { ast::BinOp::Add } else { ast::BinOp::Sub }; ast::Expression::BinOp { position, op, left: Box::new(left), right: Box::new(right), } }, )(s) } fn expression_bit(s: &str) -> IResult { let (s, mut init) = map(expression_sum, Some)(s)?; fold_many0( pair( ws(pair(position, alt((char('&'), char('|'), char('^'))))), expression_sum, ), move || init.take().unwrap(), |left, ((position, op), right)| { let op = match op { '&' => ast::BinOp::And, '|' => ast::BinOp::Or, '^' => ast::BinOp::Xor, _ => unreachable!(), }; ast::Expression::BinOp { position, op, left: Box::new(left), right: Box::new(right), } }, )(s) } fn expression_cmp(s: &str) -> IResult { let (s, mut init) = map(expression_bit, Some)(s)?; fold_many0( pair( ws(pair( position, alt(( tag("=="), tag("!="), tag("<="), tag("<"), tag(">="), tag(">"), )), )), expression_bit, ), move || init.take().unwrap(), |left, ((position, op), right)| { let op = match op { "==" => ast::BinOp::Eq, "!=" => ast::BinOp::Ne, "<=" => ast::BinOp::Le, "<" => ast::BinOp::Lt, ">=" => ast::BinOp::Ge, ">" => ast::BinOp::Gt, _ => unreachable!(), }; ast::Expression::BinOp { position, op, left: Box::new(left), right: Box::new(right), } }, )(s) } fn block_expression(s: &str) -> IResult { loop_(s) } fn loop_(s: &str) -> IResult { let (s, position) = ws(position)(s)?; let (s, _) = tag("loop")(s)?; cut(move |s| { let (s, label) = identifier(s)?; let (s, block) = block(s)?; Ok(( s, ast::Expression::Loop { position, label: label, block: Box::new(block), }, )) })(s) } fn integer(s: &str) -> IResult { ws(map_res( recognize(pair(opt(char('-')), digit1)), |n: &str| n.parse::(), ))(s) } fn visibility(s: &str) -> IResult { ws(alt(( value(ast::Visibility::Export, tag("export")), value(ast::Visibility::Import, tag("import")), success(ast::Visibility::Local), )))(s) } fn type_(s: &str) -> IResult { ws(alt(( value(ast::Type::I32, tag("i32")), value(ast::Type::I64, tag("i64")), value(ast::Type::F32, tag("f32")), value(ast::Type::F64, tag("f64")), )))(s) } fn identifier(s: &str) -> IResult<&str> { ws(recognize(pair( alt((alpha1, tag("_"))), many0(alt((alphanumeric1, tag("_")))), )))(s) } fn position(s: &str) -> IResult { Ok((s, ast::Position(s.len()))) } fn ws<'a, F: 'a, O>(inner: F) -> impl FnMut(&'a str) -> IResult where F: FnMut(&'a str) -> IResult<'a, O>, { preceded(multispace0, inner) } #[cfg(test)] mod test { use nom::combinator::all_consuming; #[test] fn identifier() { all_consuming(super::identifier)("_froobaz123").unwrap(); } #[test] fn type_() { all_consuming(super::type_)("i32").unwrap(); all_consuming(super::type_)("i64").unwrap(); all_consuming(super::type_)("f32").unwrap(); all_consuming(super::type_)("f64").unwrap(); } #[test] fn integer() { all_consuming(super::integer)("123").unwrap(); all_consuming(super::integer)("-123").unwrap(); } #[test] fn local_var() { all_consuming(super::local_var)("let foo: i32;").unwrap(); all_consuming(super::local_var)("let bar = 42;").unwrap(); } #[test] fn function() { all_consuming(super::function)("export fn foo(a: i32, b: f32) -> i32 { let x = 42; x }") .unwrap(); } #[test] fn loop_() { all_consuming(super::loop_)("loop foo { 42 }").unwrap(); all_consuming(super::loop_)("loop foo { i?64 = (i % 320 + time / 10) ^ (i / 320); }") .unwrap(); } #[test] fn block() { all_consuming(super::block)("{loop frame {}}").unwrap(); } #[test] fn expression() { all_consuming(super::expression)("foo + 2 * (bar ^ 23)").unwrap(); all_consuming(super::expression)("i := i + 1").unwrap(); all_consuming(super::expression)("(i := i + 1)").unwrap(); } #[test] fn poke() { all_consuming(super::statement)("i?64 = (i % 320 + time / 10) ^ (i / 320);").unwrap(); } #[test] fn branch_if() { all_consuming(super::branch_if)("branch_if (i := i + 1) < 10: foo").unwrap(); } }