From e9f115ee950553b0491857017ad41e67753bb1ed Mon Sep 17 00:00:00 2001 From: Dennis Ranke Date: Tue, 9 Nov 2021 22:16:40 +0100 Subject: [PATCH] implement table of intrinsics --- src/ast.rs | 27 +++++++-- src/emit.rs | 26 ++++----- src/intrinsics.rs | 140 ++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 22 +------- src/parser.rs | 11 +++- src/typecheck.rs | 95 +++++++++++++++---------------- 6 files changed, 229 insertions(+), 92 deletions(-) create mode 100644 src/intrinsics.rs diff --git a/src/ast.rs b/src/ast.rs index d8cf062..588621f 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,3 +1,5 @@ +use std::fmt; + use crate::Span; #[derive(Debug)] @@ -29,7 +31,11 @@ pub enum ImportType { type_: Type, mutable: bool, }, - Function { name: String, params: Vec, result: Option } + Function { + name: String, + params: Vec, + result: Option, + }, } #[derive(Debug)] @@ -38,7 +44,7 @@ pub struct GlobalVar { pub name: String, pub value: Expression, pub type_: Option, - pub mutable: bool + pub mutable: bool, } #[derive(Debug)] @@ -130,10 +136,10 @@ pub enum Expr { If { condition: Box, if_true: Box, - if_false: Option> + if_false: Option>, }, Return { - value: Option> + value: Option>, }, Error, } @@ -151,7 +157,7 @@ impl Expr { #[derive(Debug, Clone, Copy)] pub enum UnaryOp { Negate, - Not + Not, } #[derive(Debug, Clone, Copy)] @@ -188,3 +194,14 @@ pub enum Type { F32, F64, } + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Type::I32 => write!(f, "i32"), + Type::I64 => write!(f, "i64"), + Type::F32 => write!(f, "f32"), + Type::F64 => write!(f, "f64"), + } + } +} diff --git a/src/emit.rs b/src/emit.rs index 88aaab8..da87a8a 100644 --- a/src/emit.rs +++ b/src/emit.rs @@ -6,7 +6,7 @@ use wasm_encoder::{ ValType, }; -use crate::ast; +use crate::{ast, intrinsics::Intrinsics}; pub fn emit(script: &ast::Script) -> Vec { let mut module = Module::new(); @@ -94,6 +94,8 @@ pub fn emit(script: &ast::Script) -> Vec { let mut exports = ExportSection::new(); let mut code = CodeSection::new(); + let intrinsics = Intrinsics::new(); + for func in script.functions.iter() { function_map.insert(func.name.clone(), function_map.len() as u32); } @@ -108,7 +110,7 @@ pub fn emit(script: &ast::Script) -> Vec { ); } - code.function(&emit_function(func, &globals, &function_map)); + code.function(&emit_function(func, &globals, &function_map, &intrinsics)); } module.section(&functions); @@ -158,6 +160,8 @@ fn const_instr(expr: &ast::Expression) -> Instruction { match expr.expr { ast::Expr::I32Const(v) => Instruction::I32Const(v), ast::Expr::F32Const(v) => Instruction::F32Const(v), + ast::Expr::I64Const(v) => Instruction::I64Const(v), + ast::Expr::F64Const(v) => Instruction::F64Const(v), _ => unreachable!(), } } @@ -169,12 +173,14 @@ struct FunctionContext<'a> { locals: &'a HashMap, labels: Vec, deferred_inits: HashMap<&'a str, &'a ast::Expression>, + intrinsics: &'a Intrinsics, } fn emit_function( func: &ast::Function, globals: &HashMap<&str, u32>, functions: &HashMap, + intrinsics: &Intrinsics, ) -> Function { let mut locals = Vec::new(); collect_locals_expr(&func.body, &mut locals); @@ -199,6 +205,7 @@ fn emit_function( locals: &local_map, labels: vec![], deferred_inits: HashMap::new(), + intrinsics, }; emit_expression(&mut context, &func.body); @@ -563,6 +570,7 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression) for param in params { emit_expression(ctx, param); } + if let Some(index) = ctx.functions.get(name) { ctx.function.instruction(&Instruction::Call(*index)); } else { @@ -571,7 +579,7 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression) types.push(param.type_.unwrap()); } ctx.function - .instruction(&builtin_function(name, &types).unwrap()); + .instruction(&ctx.intrinsics.get_instr(name, &types).unwrap()); } } ast::Expr::Select { @@ -634,15 +642,3 @@ fn map_block_type(t: Option) -> BlockType { BlockType::Empty } } - -fn builtin_function(name: &str, params: &[ast::Type]) -> Option> { - use ast::Type::*; - let inst = match (name, params) { - ("sqrt", &[F32]) => Instruction::F32Sqrt, - ("abs", &[F32]) => Instruction::F32Abs, - ("min", &[F32, F32]) => Instruction::F32Min, - ("max", &[F32, F32]) => Instruction::F32Max, - _ => return None, - }; - Some(inst) -} diff --git a/src/intrinsics.rs b/src/intrinsics.rs new file mode 100644 index 0000000..e9a9b8c --- /dev/null +++ b/src/intrinsics.rs @@ -0,0 +1,140 @@ +use crate::ast::Type; +use std::collections::HashMap; +use wasm_encoder as enc; + +pub struct Intrinsics(HashMap, (Type, enc::Instruction<'static>)>>); + +impl Intrinsics { + pub fn new() -> Intrinsics { + let mut i = Intrinsics(HashMap::new()); + i.add_instructions(); + i + } + + pub fn find_types( + &self, + name: &str, + ) -> Option, Option>> { + self.0.get(name).map(|types| { + types + .iter() + .map(|(params, (ret, _))| (params.clone(), Some(*ret))) + .collect() + }) + } + + pub fn get_instr(&self, name: &str, params: &[Type]) -> Option> { + self.0 + .get(name) + .and_then(|types| types.get(params)) + .map(|(_, i)| i.clone()) + } + + fn add_instructions(&mut self) { + use enc::Instruction as I; + use Type::*; + self.inst("i32.rotl", &[I32, I32], I32, I::I32Rotl); + self.inst("i32.rotr", &[I32, I32], I32, I::I32Rotr); + self.inst("i32.clz", &[I32], I32, I::I32Clz); + self.inst("i32.ctz", &[I32], I32, I::I32Ctz); + self.inst("i32.popcnt", &[I32], I32, I::I32Popcnt); + + self.inst("i64.rotl", &[I64, I64], I64, I::I64Rotl); + self.inst("i64.rotr", &[I64, I64], I64, I::I64Rotr); + self.inst("i64.clz", &[I64], I64, I::I64Clz); + self.inst("i64.ctz", &[I64], I64, I::I64Ctz); + self.inst("i64.popcnt", &[I64], I64, I::I64Popcnt); + + self.inst("f32/sqrt", &[F32], F32, I::F32Sqrt); + self.inst("f32/min", &[F32, F32], F32, I::F32Min); + self.inst("f32/max", &[F32, F32], F32, I::F32Max); + self.inst("f32/ceil", &[F32], F32, I::F32Ceil); + self.inst("f32/floor", &[F32], F32, I::F32Floor); + self.inst("f32/trunc", &[F32], F32, I::F32Trunc); + self.inst("f32/nearest", &[F32], F32, I::F32Nearest); + self.inst("f32/abs", &[F32], F32, I::F32Abs); + self.inst("f32.copysign", &[F32, F32], F32, I::F32Copysign); + + self.inst("f64/sqrt", &[F64], F64, I::F64Sqrt); + self.inst("f64/min", &[F64, F64], F64, I::F64Min); + self.inst("f64/max", &[F64, F64], F64, I::F64Max); + self.inst("f64/ceil", &[F64], F64, I::F64Ceil); + self.inst("f64/floor", &[F64], F64, I::F64Floor); + self.inst("f64/trunc", &[F64], F64, I::F64Trunc); + self.inst("f64/nearest", &[F64], F64, I::F64Nearest); + self.inst("f64/abs", &[F64], F64, I::F64Abs); + self.inst("f64.copysign", &[F64, F64], F64, I::F64Copysign); + + self.inst("i32.wrap_i64", &[I64], I32, I::I32WrapI64); + self.inst("i64.extend_i32_s", &[I32], I64, I::I64ExtendI32S); + self.inst("i64.extend_i32_u", &[I32], I64, I::I64ExtendI32U); + + self.inst("i32.trunc_f32_s", &[F32], I32, I::I32TruncF32S); + self.inst("i32.trunc_f64_s", &[F64], I32, I::I32TruncF64S); + self.inst("i64.trunc_f32_s", &[F32], I64, I::I64TruncF32S); + self.inst("i64.trunc_f64_s", &[F64], I64, I::I64TruncF64S); + + self.inst("i32.trunc_f32_u", &[F32], I32, I::I32TruncF32U); + self.inst("i32.trunc_f64_u", &[F64], I32, I::I32TruncF64U); + self.inst("i64.trunc_f32_u", &[F32], I64, I::I64TruncF32U); + self.inst("i64.trunc_f64_u", &[F64], I64, I::I64TruncF64U); + + self.inst("f32.demote_f64", &[F64], F32, I::F32DemoteF64); + self.inst("f64.promote_f32", &[F32], F64, I::F64PromoteF32); + + self.inst("f32.convert_i32_s", &[I32], F32, I::F32ConvertI32S); + self.inst("f32.convert_i64_s", &[I64], F32, I::F32ConvertI32S); + self.inst("f64.convert_i32_s", &[I32], F64, I::F32ConvertI32S); + self.inst("f64.convert_i64_s", &[I64], F64, I::F32ConvertI32S); + + self.inst("f32.convert_i32_u", &[I32], F32, I::F32ConvertI32U); + self.inst("f32.convert_i64_u", &[I64], F32, I::F32ConvertI32U); + self.inst("f64.convert_i32_u", &[I32], F64, I::F32ConvertI32U); + self.inst("f64.convert_i64_u", &[I64], F64, I::F32ConvertI32U); + + self.inst("i32.reinterpret_f32", &[F32], I32, I::I32ReinterpretF32); + self.inst("i64.reinterpret_f64", &[F64], I64, I::I64ReinterpretF64); + self.inst("f32.reinterpret_i32", &[I32], F32, I::F32ReinterpretI32); + self.inst("f64.reinterpret_i64", &[I64], F64, I::F64ReinterpretI64); + + self.inst("i32.extend8_s", &[I32], I32, I::I32Extend8S); + self.inst("i32.extend16_s", &[I32], I32, I::I32Extend16S); + self.inst("i64.extend8_s", &[I64], I64, I::I64Extend8S); + self.inst("i64.extend16_s", &[I64], I64, I::I64Extend16S); + self.inst("i64.extend32_s", &[I64], I64, I::I64Extend32S); + + self.inst("i32.trunc_sat_f32_s", &[F32], I32, I::I32TruncSatF32S); + self.inst("i32.trunc_sat_f32_u", &[F32], I32, I::I32TruncSatF32U); + self.inst("i32.trunc_sat_f64_s", &[F64], I32, I::I32TruncSatF64S); + self.inst("i32.trunc_sat_f64_u", &[F64], I32, I::I32TruncSatF64U); + self.inst("i64.trunc_sat_f32_s", &[F32], I64, I::I64TruncSatF32S); + self.inst("i64.trunc_sat_f32_u", &[F32], I64, I::I64TruncSatF32U); + self.inst("i64.trunc_sat_f64_s", &[F64], I64, I::I64TruncSatF64S); + self.inst("i64.trunc_sat_f64_u", &[F64], I64, I::I64TruncSatF64U); + } + + fn inst(&mut self, name: &str, params: &[Type], ret: Type, ins: enc::Instruction<'static>) { + if let Some(slash_idx) = name.find('/') { + self.insert(name[(slash_idx + 1)..].to_string(), params, ret, &ins); + let mut full_name = name[..slash_idx].to_string(); + full_name.push('.'); + full_name += &name[(slash_idx + 1)..]; + self.insert(full_name, params, ret, &ins); + } else { + self.insert(name.to_string(), params, ret, &ins); + } + } + + fn insert( + &mut self, + name: String, + params: &[Type], + ret: Type, + ins: &enc::Instruction<'static>, + ) { + self.0 + .entry(name) + .or_default() + .insert(params.to_vec(), (ret, ins.clone())); + } +} diff --git a/src/main.rs b/src/main.rs index 8272dab..2f89a95 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ mod constfold; mod emit; mod parser; mod typecheck; +mod intrinsics; type Span = std::ops::Range; @@ -35,26 +36,5 @@ fn main() -> Result<()> { filename.set_extension("wasm"); File::create(filename)?.write_all(&wasm)?; - println!("Size of code section: {} bytes", code_section_size(&wasm)?); - Ok(()) } - -fn code_section_size(wasm: &[u8]) -> Result { - for payload in wasmparser::Parser::new(0).parse_all(wasm) { - match payload? { - wasmparser::Payload::CodeSectionStart { range, .. } => { - let size = range.end - range.start; - let section_header_size = match size { - 0..=127 => 2, - 128..=16383 => 3, - _ => 4, - }; - return Ok(size + section_header_size); - } - _ => (), - } - } - - bail!("No code section found"); -} diff --git a/src/parser.rs b/src/parser.rs index a25bc5d..ef8ce74 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -201,7 +201,16 @@ fn lexer() -> impl Parser, Error = Simple> { let ctrl = one_of("(){};,:?!".chars()).map(Token::Ctrl); - let ident = text::ident().map(|ident: String| match ident.as_str() { + fn ident() -> impl Parser> + Copy + Clone { + filter(|c: &char| c.is_ascii_alphabetic() || *c == '_') + .map(Some) + .chain::, _>( + filter(|c: &char| c.is_ascii_alphanumeric() || *c == '_' || *c == '.').repeated(), + ) + .collect() + } + + let ident = ident().map(|ident: String| match ident.as_str() { "import" => Token::Import, "export" => Token::Export, "fn" => Token::Fn, diff --git a/src/typecheck.rs b/src/typecheck.rs index 130fd56..07a0a08 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -2,6 +2,7 @@ use ariadne::{Color, Label, Report, ReportKind, Source}; use std::collections::HashMap; use crate::ast; +use crate::intrinsics::Intrinsics; use crate::Span; use ast::Type::*; @@ -22,6 +23,7 @@ pub fn tc_script(script: &mut ast::Script, source: &str) -> Result<()> { local_vars: HashMap::new(), block_stack: Vec::new(), return_type: None, + intrinsics: Intrinsics::new(), }; let mut result = Ok(()); @@ -164,6 +166,7 @@ struct Context<'a> { local_vars: Vars, block_stack: Vec, return_type: Option, + intrinsics: Intrinsics, } fn report_duplicate_definition( @@ -382,10 +385,7 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<() ast::Expr::I64Const(_) => Some(ast::Type::I64), ast::Expr::F32Const(_) => Some(ast::Type::F32), ast::Expr::F64Const(_) => Some(ast::Type::F64), - ast::Expr::UnaryOp { - op, - ref mut value, - } => { + ast::Expr::UnaryOp { op, ref mut value } => { tc_expression(context, value)?; if value.type_.is_none() { return expected_type(&value.span, context.source); @@ -395,7 +395,15 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<() Some(match (value.type_.unwrap(), op) { (t, Negate) => t, (I32 | I64, Not) => I32, - (_, Not) => return type_mismatch(Some(I32), &expr.span, value.type_, &value.span, context.source) + (_, Not) => { + return type_mismatch( + Some(I32), + &expr.span, + value.type_, + &value.span, + context.source, + ) + } }) } ast::Expr::BinOp { @@ -557,46 +565,43 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<() } => { for param in params.iter_mut() { tc_expression(context, param)?; + if param.type_.is_none() { + return expected_type(¶m.span, context.source); + } } - if let Some((ptypes, rtype)) = context + if let Some(type_map) = context .functions .get(name) - .map(|fnc| (fnc.params.as_slice(), fnc.type_)) - .or_else(|| builtin_function_types(name)) + .map(|fnc| HashMap::from_iter([(fnc.params.clone(), fnc.type_)])) + .or_else(|| context.intrinsics.find_types(name)) { - if params.len() != ptypes.len() { - Report::build(ReportKind::Error, (), expr.span.start) - .with_message(format!( - "Expected {} parameters but found {}", - ptypes.len(), - params.len() - )) - .with_label( - Label::new(expr.span.clone()) - .with_message(format!( - "Expected {} parameters but found {}", - ptypes.len(), - params.len() - )) - .with_color(Color::Red), - ) + if let Some(rtype) = + type_map.get(¶ms.iter().map(|p| p.type_.unwrap()).collect::>()) + { + *rtype + } else { + let mut report = Report::build(ReportKind::Error, (), expr.span.start) + .with_message("No matching function found"); + for (params, rtype) in type_map { + let param_str: Vec<_> = params.into_iter().map(|t| t.to_string()).collect(); + let msg = format!( + "Found {}({}){}", + name, + param_str.join(", "), + if let Some(rtype) = rtype { + format!(" -> {}", rtype) + } else { + String::new() + } + ); + report = report.with_label(Label::new(expr.span.clone()).with_message(msg)); + } + report .finish() .eprint(Source::from(context.source)) .unwrap(); return Err(()); } - for (ptype, param) in ptypes.iter().zip(params.iter()) { - if param.type_ != Some(*ptype) { - return type_mismatch( - Some(*ptype), - &expr.span, - param.type_, - ¶m.span, - context.source, - ); - } - } - rtype } else { Report::build(ReportKind::Error, (), expr.span.start) .with_message(format!("Unknown function {}", name)) @@ -718,13 +723,15 @@ fn tc_const(expr: &mut ast::Expression, source: &str) -> Result<()> { use ast::Expr::*; expr.type_ = Some(match expr.expr { I32Const(_) => I32, + I64Const(_) => I64, F32Const(_) => F32, + F64Const(_) => F64, _ => { Report::build(ReportKind::Error, (), expr.span.start) - .with_message("Expected I32 constant") + .with_message("Expected constant value") .with_label( Label::new(expr.span.clone()) - .with_message("Expected I32 constant") + .with_message("Expected constant value") .with_color(Color::Red), ) .finish() @@ -735,15 +742,3 @@ fn tc_const(expr: &mut ast::Expression, source: &str) -> Result<()> { }); Ok(()) } - -fn builtin_function_types(name: &str) -> Option<(&'static [ast::Type], Option)> { - use ast::Type::*; - let types: (&'static [ast::Type], Option) = match name { - "sqrt" => (&[F32], Some(F32)), - "abs" => (&[F32], Some(F32)), - "min" => (&[F32, F32], Some(F32)), - "max" => (&[F32, F32], Some(F32)), - _ => return None, - }; - Some(types) -}