implement table of intrinsics

This commit is contained in:
2021-11-09 22:16:40 +01:00
parent 41ec5a770f
commit e9f115ee95
6 changed files with 229 additions and 92 deletions

View File

@@ -1,3 +1,5 @@
use std::fmt;
use crate::Span; use crate::Span;
#[derive(Debug)] #[derive(Debug)]
@@ -29,7 +31,11 @@ pub enum ImportType {
type_: Type, type_: Type,
mutable: bool, mutable: bool,
}, },
Function { name: String, params: Vec<Type>, result: Option<Type> } Function {
name: String,
params: Vec<Type>,
result: Option<Type>,
},
} }
#[derive(Debug)] #[derive(Debug)]
@@ -38,7 +44,7 @@ pub struct GlobalVar {
pub name: String, pub name: String,
pub value: Expression, pub value: Expression,
pub type_: Option<Type>, pub type_: Option<Type>,
pub mutable: bool pub mutable: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@@ -130,10 +136,10 @@ pub enum Expr {
If { If {
condition: Box<Expression>, condition: Box<Expression>,
if_true: Box<Expression>, if_true: Box<Expression>,
if_false: Option<Box<Expression>> if_false: Option<Box<Expression>>,
}, },
Return { Return {
value: Option<Box<Expression>> value: Option<Box<Expression>>,
}, },
Error, Error,
} }
@@ -151,7 +157,7 @@ impl Expr {
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum UnaryOp { pub enum UnaryOp {
Negate, Negate,
Not Not,
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@@ -188,3 +194,14 @@ pub enum Type {
F32, F32,
F64, F64,
} }
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Type::I32 => write!(f, "i32"),
Type::I64 => write!(f, "i64"),
Type::F32 => write!(f, "f32"),
Type::F64 => write!(f, "f64"),
}
}
}

View File

@@ -6,7 +6,7 @@ use wasm_encoder::{
ValType, ValType,
}; };
use crate::ast; use crate::{ast, intrinsics::Intrinsics};
pub fn emit(script: &ast::Script) -> Vec<u8> { pub fn emit(script: &ast::Script) -> Vec<u8> {
let mut module = Module::new(); let mut module = Module::new();
@@ -94,6 +94,8 @@ pub fn emit(script: &ast::Script) -> Vec<u8> {
let mut exports = ExportSection::new(); let mut exports = ExportSection::new();
let mut code = CodeSection::new(); let mut code = CodeSection::new();
let intrinsics = Intrinsics::new();
for func in script.functions.iter() { for func in script.functions.iter() {
function_map.insert(func.name.clone(), function_map.len() as u32); function_map.insert(func.name.clone(), function_map.len() as u32);
} }
@@ -108,7 +110,7 @@ pub fn emit(script: &ast::Script) -> Vec<u8> {
); );
} }
code.function(&emit_function(func, &globals, &function_map)); code.function(&emit_function(func, &globals, &function_map, &intrinsics));
} }
module.section(&functions); module.section(&functions);
@@ -158,6 +160,8 @@ fn const_instr(expr: &ast::Expression) -> Instruction {
match expr.expr { match expr.expr {
ast::Expr::I32Const(v) => Instruction::I32Const(v), ast::Expr::I32Const(v) => Instruction::I32Const(v),
ast::Expr::F32Const(v) => Instruction::F32Const(v), ast::Expr::F32Const(v) => Instruction::F32Const(v),
ast::Expr::I64Const(v) => Instruction::I64Const(v),
ast::Expr::F64Const(v) => Instruction::F64Const(v),
_ => unreachable!(), _ => unreachable!(),
} }
} }
@@ -169,12 +173,14 @@ struct FunctionContext<'a> {
locals: &'a HashMap<String, u32>, locals: &'a HashMap<String, u32>,
labels: Vec<String>, labels: Vec<String>,
deferred_inits: HashMap<&'a str, &'a ast::Expression>, deferred_inits: HashMap<&'a str, &'a ast::Expression>,
intrinsics: &'a Intrinsics,
} }
fn emit_function( fn emit_function(
func: &ast::Function, func: &ast::Function,
globals: &HashMap<&str, u32>, globals: &HashMap<&str, u32>,
functions: &HashMap<String, u32>, functions: &HashMap<String, u32>,
intrinsics: &Intrinsics,
) -> Function { ) -> Function {
let mut locals = Vec::new(); let mut locals = Vec::new();
collect_locals_expr(&func.body, &mut locals); collect_locals_expr(&func.body, &mut locals);
@@ -199,6 +205,7 @@ fn emit_function(
locals: &local_map, locals: &local_map,
labels: vec![], labels: vec![],
deferred_inits: HashMap::new(), deferred_inits: HashMap::new(),
intrinsics,
}; };
emit_expression(&mut context, &func.body); emit_expression(&mut context, &func.body);
@@ -563,6 +570,7 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression)
for param in params { for param in params {
emit_expression(ctx, param); emit_expression(ctx, param);
} }
if let Some(index) = ctx.functions.get(name) { if let Some(index) = ctx.functions.get(name) {
ctx.function.instruction(&Instruction::Call(*index)); ctx.function.instruction(&Instruction::Call(*index));
} else { } else {
@@ -571,7 +579,7 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression)
types.push(param.type_.unwrap()); types.push(param.type_.unwrap());
} }
ctx.function ctx.function
.instruction(&builtin_function(name, &types).unwrap()); .instruction(&ctx.intrinsics.get_instr(name, &types).unwrap());
} }
} }
ast::Expr::Select { ast::Expr::Select {
@@ -634,15 +642,3 @@ fn map_block_type(t: Option<ast::Type>) -> BlockType {
BlockType::Empty BlockType::Empty
} }
} }
fn builtin_function(name: &str, params: &[ast::Type]) -> Option<Instruction<'static>> {
use ast::Type::*;
let inst = match (name, params) {
("sqrt", &[F32]) => Instruction::F32Sqrt,
("abs", &[F32]) => Instruction::F32Abs,
("min", &[F32, F32]) => Instruction::F32Min,
("max", &[F32, F32]) => Instruction::F32Max,
_ => return None,
};
Some(inst)
}

140
src/intrinsics.rs Normal file
View File

@@ -0,0 +1,140 @@
use crate::ast::Type;
use std::collections::HashMap;
use wasm_encoder as enc;
pub struct Intrinsics(HashMap<String, HashMap<Vec<Type>, (Type, enc::Instruction<'static>)>>);
impl Intrinsics {
pub fn new() -> Intrinsics {
let mut i = Intrinsics(HashMap::new());
i.add_instructions();
i
}
pub fn find_types(
&self,
name: &str,
) -> Option<HashMap<Vec<Type>, Option<Type>>> {
self.0.get(name).map(|types| {
types
.iter()
.map(|(params, (ret, _))| (params.clone(), Some(*ret)))
.collect()
})
}
pub fn get_instr(&self, name: &str, params: &[Type]) -> Option<enc::Instruction<'static>> {
self.0
.get(name)
.and_then(|types| types.get(params))
.map(|(_, i)| i.clone())
}
fn add_instructions(&mut self) {
use enc::Instruction as I;
use Type::*;
self.inst("i32.rotl", &[I32, I32], I32, I::I32Rotl);
self.inst("i32.rotr", &[I32, I32], I32, I::I32Rotr);
self.inst("i32.clz", &[I32], I32, I::I32Clz);
self.inst("i32.ctz", &[I32], I32, I::I32Ctz);
self.inst("i32.popcnt", &[I32], I32, I::I32Popcnt);
self.inst("i64.rotl", &[I64, I64], I64, I::I64Rotl);
self.inst("i64.rotr", &[I64, I64], I64, I::I64Rotr);
self.inst("i64.clz", &[I64], I64, I::I64Clz);
self.inst("i64.ctz", &[I64], I64, I::I64Ctz);
self.inst("i64.popcnt", &[I64], I64, I::I64Popcnt);
self.inst("f32/sqrt", &[F32], F32, I::F32Sqrt);
self.inst("f32/min", &[F32, F32], F32, I::F32Min);
self.inst("f32/max", &[F32, F32], F32, I::F32Max);
self.inst("f32/ceil", &[F32], F32, I::F32Ceil);
self.inst("f32/floor", &[F32], F32, I::F32Floor);
self.inst("f32/trunc", &[F32], F32, I::F32Trunc);
self.inst("f32/nearest", &[F32], F32, I::F32Nearest);
self.inst("f32/abs", &[F32], F32, I::F32Abs);
self.inst("f32.copysign", &[F32, F32], F32, I::F32Copysign);
self.inst("f64/sqrt", &[F64], F64, I::F64Sqrt);
self.inst("f64/min", &[F64, F64], F64, I::F64Min);
self.inst("f64/max", &[F64, F64], F64, I::F64Max);
self.inst("f64/ceil", &[F64], F64, I::F64Ceil);
self.inst("f64/floor", &[F64], F64, I::F64Floor);
self.inst("f64/trunc", &[F64], F64, I::F64Trunc);
self.inst("f64/nearest", &[F64], F64, I::F64Nearest);
self.inst("f64/abs", &[F64], F64, I::F64Abs);
self.inst("f64.copysign", &[F64, F64], F64, I::F64Copysign);
self.inst("i32.wrap_i64", &[I64], I32, I::I32WrapI64);
self.inst("i64.extend_i32_s", &[I32], I64, I::I64ExtendI32S);
self.inst("i64.extend_i32_u", &[I32], I64, I::I64ExtendI32U);
self.inst("i32.trunc_f32_s", &[F32], I32, I::I32TruncF32S);
self.inst("i32.trunc_f64_s", &[F64], I32, I::I32TruncF64S);
self.inst("i64.trunc_f32_s", &[F32], I64, I::I64TruncF32S);
self.inst("i64.trunc_f64_s", &[F64], I64, I::I64TruncF64S);
self.inst("i32.trunc_f32_u", &[F32], I32, I::I32TruncF32U);
self.inst("i32.trunc_f64_u", &[F64], I32, I::I32TruncF64U);
self.inst("i64.trunc_f32_u", &[F32], I64, I::I64TruncF32U);
self.inst("i64.trunc_f64_u", &[F64], I64, I::I64TruncF64U);
self.inst("f32.demote_f64", &[F64], F32, I::F32DemoteF64);
self.inst("f64.promote_f32", &[F32], F64, I::F64PromoteF32);
self.inst("f32.convert_i32_s", &[I32], F32, I::F32ConvertI32S);
self.inst("f32.convert_i64_s", &[I64], F32, I::F32ConvertI32S);
self.inst("f64.convert_i32_s", &[I32], F64, I::F32ConvertI32S);
self.inst("f64.convert_i64_s", &[I64], F64, I::F32ConvertI32S);
self.inst("f32.convert_i32_u", &[I32], F32, I::F32ConvertI32U);
self.inst("f32.convert_i64_u", &[I64], F32, I::F32ConvertI32U);
self.inst("f64.convert_i32_u", &[I32], F64, I::F32ConvertI32U);
self.inst("f64.convert_i64_u", &[I64], F64, I::F32ConvertI32U);
self.inst("i32.reinterpret_f32", &[F32], I32, I::I32ReinterpretF32);
self.inst("i64.reinterpret_f64", &[F64], I64, I::I64ReinterpretF64);
self.inst("f32.reinterpret_i32", &[I32], F32, I::F32ReinterpretI32);
self.inst("f64.reinterpret_i64", &[I64], F64, I::F64ReinterpretI64);
self.inst("i32.extend8_s", &[I32], I32, I::I32Extend8S);
self.inst("i32.extend16_s", &[I32], I32, I::I32Extend16S);
self.inst("i64.extend8_s", &[I64], I64, I::I64Extend8S);
self.inst("i64.extend16_s", &[I64], I64, I::I64Extend16S);
self.inst("i64.extend32_s", &[I64], I64, I::I64Extend32S);
self.inst("i32.trunc_sat_f32_s", &[F32], I32, I::I32TruncSatF32S);
self.inst("i32.trunc_sat_f32_u", &[F32], I32, I::I32TruncSatF32U);
self.inst("i32.trunc_sat_f64_s", &[F64], I32, I::I32TruncSatF64S);
self.inst("i32.trunc_sat_f64_u", &[F64], I32, I::I32TruncSatF64U);
self.inst("i64.trunc_sat_f32_s", &[F32], I64, I::I64TruncSatF32S);
self.inst("i64.trunc_sat_f32_u", &[F32], I64, I::I64TruncSatF32U);
self.inst("i64.trunc_sat_f64_s", &[F64], I64, I::I64TruncSatF64S);
self.inst("i64.trunc_sat_f64_u", &[F64], I64, I::I64TruncSatF64U);
}
fn inst(&mut self, name: &str, params: &[Type], ret: Type, ins: enc::Instruction<'static>) {
if let Some(slash_idx) = name.find('/') {
self.insert(name[(slash_idx + 1)..].to_string(), params, ret, &ins);
let mut full_name = name[..slash_idx].to_string();
full_name.push('.');
full_name += &name[(slash_idx + 1)..];
self.insert(full_name, params, ret, &ins);
} else {
self.insert(name.to_string(), params, ret, &ins);
}
}
fn insert(
&mut self,
name: String,
params: &[Type],
ret: Type,
ins: &enc::Instruction<'static>,
) {
self.0
.entry(name)
.or_default()
.insert(params.to_vec(), (ret, ins.clone()));
}
}

View File

@@ -7,6 +7,7 @@ mod constfold;
mod emit; mod emit;
mod parser; mod parser;
mod typecheck; mod typecheck;
mod intrinsics;
type Span = std::ops::Range<usize>; type Span = std::ops::Range<usize>;
@@ -35,26 +36,5 @@ fn main() -> Result<()> {
filename.set_extension("wasm"); filename.set_extension("wasm");
File::create(filename)?.write_all(&wasm)?; File::create(filename)?.write_all(&wasm)?;
println!("Size of code section: {} bytes", code_section_size(&wasm)?);
Ok(()) Ok(())
} }
fn code_section_size(wasm: &[u8]) -> Result<usize> {
for payload in wasmparser::Parser::new(0).parse_all(wasm) {
match payload? {
wasmparser::Payload::CodeSectionStart { range, .. } => {
let size = range.end - range.start;
let section_header_size = match size {
0..=127 => 2,
128..=16383 => 3,
_ => 4,
};
return Ok(size + section_header_size);
}
_ => (),
}
}
bail!("No code section found");
}

View File

@@ -201,7 +201,16 @@ fn lexer() -> impl Parser<char, Vec<(Token, Span)>, Error = Simple<char>> {
let ctrl = one_of("(){};,:?!".chars()).map(Token::Ctrl); let ctrl = one_of("(){};,:?!".chars()).map(Token::Ctrl);
let ident = text::ident().map(|ident: String| match ident.as_str() { fn ident() -> impl Parser<char, String, Error = Simple<char>> + Copy + Clone {
filter(|c: &char| c.is_ascii_alphabetic() || *c == '_')
.map(Some)
.chain::<char, Vec<_>, _>(
filter(|c: &char| c.is_ascii_alphanumeric() || *c == '_' || *c == '.').repeated(),
)
.collect()
}
let ident = ident().map(|ident: String| match ident.as_str() {
"import" => Token::Import, "import" => Token::Import,
"export" => Token::Export, "export" => Token::Export,
"fn" => Token::Fn, "fn" => Token::Fn,

View File

@@ -2,6 +2,7 @@ use ariadne::{Color, Label, Report, ReportKind, Source};
use std::collections::HashMap; use std::collections::HashMap;
use crate::ast; use crate::ast;
use crate::intrinsics::Intrinsics;
use crate::Span; use crate::Span;
use ast::Type::*; use ast::Type::*;
@@ -22,6 +23,7 @@ pub fn tc_script(script: &mut ast::Script, source: &str) -> Result<()> {
local_vars: HashMap::new(), local_vars: HashMap::new(),
block_stack: Vec::new(), block_stack: Vec::new(),
return_type: None, return_type: None,
intrinsics: Intrinsics::new(),
}; };
let mut result = Ok(()); let mut result = Ok(());
@@ -164,6 +166,7 @@ struct Context<'a> {
local_vars: Vars, local_vars: Vars,
block_stack: Vec<String>, block_stack: Vec<String>,
return_type: Option<ast::Type>, return_type: Option<ast::Type>,
intrinsics: Intrinsics,
} }
fn report_duplicate_definition( fn report_duplicate_definition(
@@ -382,10 +385,7 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<()
ast::Expr::I64Const(_) => Some(ast::Type::I64), ast::Expr::I64Const(_) => Some(ast::Type::I64),
ast::Expr::F32Const(_) => Some(ast::Type::F32), ast::Expr::F32Const(_) => Some(ast::Type::F32),
ast::Expr::F64Const(_) => Some(ast::Type::F64), ast::Expr::F64Const(_) => Some(ast::Type::F64),
ast::Expr::UnaryOp { ast::Expr::UnaryOp { op, ref mut value } => {
op,
ref mut value,
} => {
tc_expression(context, value)?; tc_expression(context, value)?;
if value.type_.is_none() { if value.type_.is_none() {
return expected_type(&value.span, context.source); return expected_type(&value.span, context.source);
@@ -395,7 +395,15 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<()
Some(match (value.type_.unwrap(), op) { Some(match (value.type_.unwrap(), op) {
(t, Negate) => t, (t, Negate) => t,
(I32 | I64, Not) => I32, (I32 | I64, Not) => I32,
(_, Not) => return type_mismatch(Some(I32), &expr.span, value.type_, &value.span, context.source) (_, Not) => {
return type_mismatch(
Some(I32),
&expr.span,
value.type_,
&value.span,
context.source,
)
}
}) })
} }
ast::Expr::BinOp { ast::Expr::BinOp {
@@ -557,46 +565,43 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<()
} => { } => {
for param in params.iter_mut() { for param in params.iter_mut() {
tc_expression(context, param)?; tc_expression(context, param)?;
if param.type_.is_none() {
return expected_type(&param.span, context.source);
}
} }
if let Some((ptypes, rtype)) = context if let Some(type_map) = context
.functions .functions
.get(name) .get(name)
.map(|fnc| (fnc.params.as_slice(), fnc.type_)) .map(|fnc| HashMap::from_iter([(fnc.params.clone(), fnc.type_)]))
.or_else(|| builtin_function_types(name)) .or_else(|| context.intrinsics.find_types(name))
{ {
if params.len() != ptypes.len() { if let Some(rtype) =
Report::build(ReportKind::Error, (), expr.span.start) type_map.get(&params.iter().map(|p| p.type_.unwrap()).collect::<Vec<_>>())
.with_message(format!( {
"Expected {} parameters but found {}", *rtype
ptypes.len(), } else {
params.len() let mut report = Report::build(ReportKind::Error, (), expr.span.start)
)) .with_message("No matching function found");
.with_label( for (params, rtype) in type_map {
Label::new(expr.span.clone()) let param_str: Vec<_> = params.into_iter().map(|t| t.to_string()).collect();
.with_message(format!( let msg = format!(
"Expected {} parameters but found {}", "Found {}({}){}",
ptypes.len(), name,
params.len() param_str.join(", "),
)) if let Some(rtype) = rtype {
.with_color(Color::Red), format!(" -> {}", rtype)
) } else {
String::new()
}
);
report = report.with_label(Label::new(expr.span.clone()).with_message(msg));
}
report
.finish() .finish()
.eprint(Source::from(context.source)) .eprint(Source::from(context.source))
.unwrap(); .unwrap();
return Err(()); return Err(());
} }
for (ptype, param) in ptypes.iter().zip(params.iter()) {
if param.type_ != Some(*ptype) {
return type_mismatch(
Some(*ptype),
&expr.span,
param.type_,
&param.span,
context.source,
);
}
}
rtype
} else { } else {
Report::build(ReportKind::Error, (), expr.span.start) Report::build(ReportKind::Error, (), expr.span.start)
.with_message(format!("Unknown function {}", name)) .with_message(format!("Unknown function {}", name))
@@ -718,13 +723,15 @@ fn tc_const(expr: &mut ast::Expression, source: &str) -> Result<()> {
use ast::Expr::*; use ast::Expr::*;
expr.type_ = Some(match expr.expr { expr.type_ = Some(match expr.expr {
I32Const(_) => I32, I32Const(_) => I32,
I64Const(_) => I64,
F32Const(_) => F32, F32Const(_) => F32,
F64Const(_) => F64,
_ => { _ => {
Report::build(ReportKind::Error, (), expr.span.start) Report::build(ReportKind::Error, (), expr.span.start)
.with_message("Expected I32 constant") .with_message("Expected constant value")
.with_label( .with_label(
Label::new(expr.span.clone()) Label::new(expr.span.clone())
.with_message("Expected I32 constant") .with_message("Expected constant value")
.with_color(Color::Red), .with_color(Color::Red),
) )
.finish() .finish()
@@ -735,15 +742,3 @@ fn tc_const(expr: &mut ast::Expression, source: &str) -> Result<()> {
}); });
Ok(()) Ok(())
} }
fn builtin_function_types(name: &str) -> Option<(&'static [ast::Type], Option<ast::Type>)> {
use ast::Type::*;
let types: (&'static [ast::Type], Option<ast::Type>) = match name {
"sqrt" => (&[F32], Some(F32)),
"abs" => (&[F32], Some(F32)),
"min" => (&[F32, F32], Some(F32)),
"max" => (&[F32, F32], Some(F32)),
_ => return None,
};
Some(types)
}