implement table of intrinsics

This commit is contained in:
2021-11-09 22:16:40 +01:00
parent 41ec5a770f
commit e9f115ee95
6 changed files with 229 additions and 92 deletions

View File

@@ -1,3 +1,5 @@
use std::fmt;
use crate::Span;
#[derive(Debug)]
@@ -29,7 +31,11 @@ pub enum ImportType {
type_: Type,
mutable: bool,
},
Function { name: String, params: Vec<Type>, result: Option<Type> }
Function {
name: String,
params: Vec<Type>,
result: Option<Type>,
},
}
#[derive(Debug)]
@@ -38,7 +44,7 @@ pub struct GlobalVar {
pub name: String,
pub value: Expression,
pub type_: Option<Type>,
pub mutable: bool
pub mutable: bool,
}
#[derive(Debug)]
@@ -130,10 +136,10 @@ pub enum Expr {
If {
condition: Box<Expression>,
if_true: Box<Expression>,
if_false: Option<Box<Expression>>
if_false: Option<Box<Expression>>,
},
Return {
value: Option<Box<Expression>>
value: Option<Box<Expression>>,
},
Error,
}
@@ -151,7 +157,7 @@ impl Expr {
#[derive(Debug, Clone, Copy)]
pub enum UnaryOp {
Negate,
Not
Not,
}
#[derive(Debug, Clone, Copy)]
@@ -188,3 +194,14 @@ pub enum Type {
F32,
F64,
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Type::I32 => write!(f, "i32"),
Type::I64 => write!(f, "i64"),
Type::F32 => write!(f, "f32"),
Type::F64 => write!(f, "f64"),
}
}
}

View File

@@ -6,7 +6,7 @@ use wasm_encoder::{
ValType,
};
use crate::ast;
use crate::{ast, intrinsics::Intrinsics};
pub fn emit(script: &ast::Script) -> Vec<u8> {
let mut module = Module::new();
@@ -94,6 +94,8 @@ pub fn emit(script: &ast::Script) -> Vec<u8> {
let mut exports = ExportSection::new();
let mut code = CodeSection::new();
let intrinsics = Intrinsics::new();
for func in script.functions.iter() {
function_map.insert(func.name.clone(), function_map.len() as u32);
}
@@ -108,7 +110,7 @@ pub fn emit(script: &ast::Script) -> Vec<u8> {
);
}
code.function(&emit_function(func, &globals, &function_map));
code.function(&emit_function(func, &globals, &function_map, &intrinsics));
}
module.section(&functions);
@@ -158,6 +160,8 @@ fn const_instr(expr: &ast::Expression) -> Instruction {
match expr.expr {
ast::Expr::I32Const(v) => Instruction::I32Const(v),
ast::Expr::F32Const(v) => Instruction::F32Const(v),
ast::Expr::I64Const(v) => Instruction::I64Const(v),
ast::Expr::F64Const(v) => Instruction::F64Const(v),
_ => unreachable!(),
}
}
@@ -169,12 +173,14 @@ struct FunctionContext<'a> {
locals: &'a HashMap<String, u32>,
labels: Vec<String>,
deferred_inits: HashMap<&'a str, &'a ast::Expression>,
intrinsics: &'a Intrinsics,
}
fn emit_function(
func: &ast::Function,
globals: &HashMap<&str, u32>,
functions: &HashMap<String, u32>,
intrinsics: &Intrinsics,
) -> Function {
let mut locals = Vec::new();
collect_locals_expr(&func.body, &mut locals);
@@ -199,6 +205,7 @@ fn emit_function(
locals: &local_map,
labels: vec![],
deferred_inits: HashMap::new(),
intrinsics,
};
emit_expression(&mut context, &func.body);
@@ -563,6 +570,7 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression)
for param in params {
emit_expression(ctx, param);
}
if let Some(index) = ctx.functions.get(name) {
ctx.function.instruction(&Instruction::Call(*index));
} else {
@@ -571,7 +579,7 @@ fn emit_expression<'a>(ctx: &mut FunctionContext<'a>, expr: &'a ast::Expression)
types.push(param.type_.unwrap());
}
ctx.function
.instruction(&builtin_function(name, &types).unwrap());
.instruction(&ctx.intrinsics.get_instr(name, &types).unwrap());
}
}
ast::Expr::Select {
@@ -634,15 +642,3 @@ fn map_block_type(t: Option<ast::Type>) -> BlockType {
BlockType::Empty
}
}
fn builtin_function(name: &str, params: &[ast::Type]) -> Option<Instruction<'static>> {
use ast::Type::*;
let inst = match (name, params) {
("sqrt", &[F32]) => Instruction::F32Sqrt,
("abs", &[F32]) => Instruction::F32Abs,
("min", &[F32, F32]) => Instruction::F32Min,
("max", &[F32, F32]) => Instruction::F32Max,
_ => return None,
};
Some(inst)
}

140
src/intrinsics.rs Normal file
View File

@@ -0,0 +1,140 @@
use crate::ast::Type;
use std::collections::HashMap;
use wasm_encoder as enc;
pub struct Intrinsics(HashMap<String, HashMap<Vec<Type>, (Type, enc::Instruction<'static>)>>);
impl Intrinsics {
pub fn new() -> Intrinsics {
let mut i = Intrinsics(HashMap::new());
i.add_instructions();
i
}
pub fn find_types(
&self,
name: &str,
) -> Option<HashMap<Vec<Type>, Option<Type>>> {
self.0.get(name).map(|types| {
types
.iter()
.map(|(params, (ret, _))| (params.clone(), Some(*ret)))
.collect()
})
}
pub fn get_instr(&self, name: &str, params: &[Type]) -> Option<enc::Instruction<'static>> {
self.0
.get(name)
.and_then(|types| types.get(params))
.map(|(_, i)| i.clone())
}
fn add_instructions(&mut self) {
use enc::Instruction as I;
use Type::*;
self.inst("i32.rotl", &[I32, I32], I32, I::I32Rotl);
self.inst("i32.rotr", &[I32, I32], I32, I::I32Rotr);
self.inst("i32.clz", &[I32], I32, I::I32Clz);
self.inst("i32.ctz", &[I32], I32, I::I32Ctz);
self.inst("i32.popcnt", &[I32], I32, I::I32Popcnt);
self.inst("i64.rotl", &[I64, I64], I64, I::I64Rotl);
self.inst("i64.rotr", &[I64, I64], I64, I::I64Rotr);
self.inst("i64.clz", &[I64], I64, I::I64Clz);
self.inst("i64.ctz", &[I64], I64, I::I64Ctz);
self.inst("i64.popcnt", &[I64], I64, I::I64Popcnt);
self.inst("f32/sqrt", &[F32], F32, I::F32Sqrt);
self.inst("f32/min", &[F32, F32], F32, I::F32Min);
self.inst("f32/max", &[F32, F32], F32, I::F32Max);
self.inst("f32/ceil", &[F32], F32, I::F32Ceil);
self.inst("f32/floor", &[F32], F32, I::F32Floor);
self.inst("f32/trunc", &[F32], F32, I::F32Trunc);
self.inst("f32/nearest", &[F32], F32, I::F32Nearest);
self.inst("f32/abs", &[F32], F32, I::F32Abs);
self.inst("f32.copysign", &[F32, F32], F32, I::F32Copysign);
self.inst("f64/sqrt", &[F64], F64, I::F64Sqrt);
self.inst("f64/min", &[F64, F64], F64, I::F64Min);
self.inst("f64/max", &[F64, F64], F64, I::F64Max);
self.inst("f64/ceil", &[F64], F64, I::F64Ceil);
self.inst("f64/floor", &[F64], F64, I::F64Floor);
self.inst("f64/trunc", &[F64], F64, I::F64Trunc);
self.inst("f64/nearest", &[F64], F64, I::F64Nearest);
self.inst("f64/abs", &[F64], F64, I::F64Abs);
self.inst("f64.copysign", &[F64, F64], F64, I::F64Copysign);
self.inst("i32.wrap_i64", &[I64], I32, I::I32WrapI64);
self.inst("i64.extend_i32_s", &[I32], I64, I::I64ExtendI32S);
self.inst("i64.extend_i32_u", &[I32], I64, I::I64ExtendI32U);
self.inst("i32.trunc_f32_s", &[F32], I32, I::I32TruncF32S);
self.inst("i32.trunc_f64_s", &[F64], I32, I::I32TruncF64S);
self.inst("i64.trunc_f32_s", &[F32], I64, I::I64TruncF32S);
self.inst("i64.trunc_f64_s", &[F64], I64, I::I64TruncF64S);
self.inst("i32.trunc_f32_u", &[F32], I32, I::I32TruncF32U);
self.inst("i32.trunc_f64_u", &[F64], I32, I::I32TruncF64U);
self.inst("i64.trunc_f32_u", &[F32], I64, I::I64TruncF32U);
self.inst("i64.trunc_f64_u", &[F64], I64, I::I64TruncF64U);
self.inst("f32.demote_f64", &[F64], F32, I::F32DemoteF64);
self.inst("f64.promote_f32", &[F32], F64, I::F64PromoteF32);
self.inst("f32.convert_i32_s", &[I32], F32, I::F32ConvertI32S);
self.inst("f32.convert_i64_s", &[I64], F32, I::F32ConvertI32S);
self.inst("f64.convert_i32_s", &[I32], F64, I::F32ConvertI32S);
self.inst("f64.convert_i64_s", &[I64], F64, I::F32ConvertI32S);
self.inst("f32.convert_i32_u", &[I32], F32, I::F32ConvertI32U);
self.inst("f32.convert_i64_u", &[I64], F32, I::F32ConvertI32U);
self.inst("f64.convert_i32_u", &[I32], F64, I::F32ConvertI32U);
self.inst("f64.convert_i64_u", &[I64], F64, I::F32ConvertI32U);
self.inst("i32.reinterpret_f32", &[F32], I32, I::I32ReinterpretF32);
self.inst("i64.reinterpret_f64", &[F64], I64, I::I64ReinterpretF64);
self.inst("f32.reinterpret_i32", &[I32], F32, I::F32ReinterpretI32);
self.inst("f64.reinterpret_i64", &[I64], F64, I::F64ReinterpretI64);
self.inst("i32.extend8_s", &[I32], I32, I::I32Extend8S);
self.inst("i32.extend16_s", &[I32], I32, I::I32Extend16S);
self.inst("i64.extend8_s", &[I64], I64, I::I64Extend8S);
self.inst("i64.extend16_s", &[I64], I64, I::I64Extend16S);
self.inst("i64.extend32_s", &[I64], I64, I::I64Extend32S);
self.inst("i32.trunc_sat_f32_s", &[F32], I32, I::I32TruncSatF32S);
self.inst("i32.trunc_sat_f32_u", &[F32], I32, I::I32TruncSatF32U);
self.inst("i32.trunc_sat_f64_s", &[F64], I32, I::I32TruncSatF64S);
self.inst("i32.trunc_sat_f64_u", &[F64], I32, I::I32TruncSatF64U);
self.inst("i64.trunc_sat_f32_s", &[F32], I64, I::I64TruncSatF32S);
self.inst("i64.trunc_sat_f32_u", &[F32], I64, I::I64TruncSatF32U);
self.inst("i64.trunc_sat_f64_s", &[F64], I64, I::I64TruncSatF64S);
self.inst("i64.trunc_sat_f64_u", &[F64], I64, I::I64TruncSatF64U);
}
fn inst(&mut self, name: &str, params: &[Type], ret: Type, ins: enc::Instruction<'static>) {
if let Some(slash_idx) = name.find('/') {
self.insert(name[(slash_idx + 1)..].to_string(), params, ret, &ins);
let mut full_name = name[..slash_idx].to_string();
full_name.push('.');
full_name += &name[(slash_idx + 1)..];
self.insert(full_name, params, ret, &ins);
} else {
self.insert(name.to_string(), params, ret, &ins);
}
}
fn insert(
&mut self,
name: String,
params: &[Type],
ret: Type,
ins: &enc::Instruction<'static>,
) {
self.0
.entry(name)
.or_default()
.insert(params.to_vec(), (ret, ins.clone()));
}
}

View File

@@ -7,6 +7,7 @@ mod constfold;
mod emit;
mod parser;
mod typecheck;
mod intrinsics;
type Span = std::ops::Range<usize>;
@@ -35,26 +36,5 @@ fn main() -> Result<()> {
filename.set_extension("wasm");
File::create(filename)?.write_all(&wasm)?;
println!("Size of code section: {} bytes", code_section_size(&wasm)?);
Ok(())
}
fn code_section_size(wasm: &[u8]) -> Result<usize> {
for payload in wasmparser::Parser::new(0).parse_all(wasm) {
match payload? {
wasmparser::Payload::CodeSectionStart { range, .. } => {
let size = range.end - range.start;
let section_header_size = match size {
0..=127 => 2,
128..=16383 => 3,
_ => 4,
};
return Ok(size + section_header_size);
}
_ => (),
}
}
bail!("No code section found");
}

View File

@@ -201,7 +201,16 @@ fn lexer() -> impl Parser<char, Vec<(Token, Span)>, Error = Simple<char>> {
let ctrl = one_of("(){};,:?!".chars()).map(Token::Ctrl);
let ident = text::ident().map(|ident: String| match ident.as_str() {
fn ident() -> impl Parser<char, String, Error = Simple<char>> + Copy + Clone {
filter(|c: &char| c.is_ascii_alphabetic() || *c == '_')
.map(Some)
.chain::<char, Vec<_>, _>(
filter(|c: &char| c.is_ascii_alphanumeric() || *c == '_' || *c == '.').repeated(),
)
.collect()
}
let ident = ident().map(|ident: String| match ident.as_str() {
"import" => Token::Import,
"export" => Token::Export,
"fn" => Token::Fn,

View File

@@ -2,6 +2,7 @@ use ariadne::{Color, Label, Report, ReportKind, Source};
use std::collections::HashMap;
use crate::ast;
use crate::intrinsics::Intrinsics;
use crate::Span;
use ast::Type::*;
@@ -22,6 +23,7 @@ pub fn tc_script(script: &mut ast::Script, source: &str) -> Result<()> {
local_vars: HashMap::new(),
block_stack: Vec::new(),
return_type: None,
intrinsics: Intrinsics::new(),
};
let mut result = Ok(());
@@ -164,6 +166,7 @@ struct Context<'a> {
local_vars: Vars,
block_stack: Vec<String>,
return_type: Option<ast::Type>,
intrinsics: Intrinsics,
}
fn report_duplicate_definition(
@@ -382,10 +385,7 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<()
ast::Expr::I64Const(_) => Some(ast::Type::I64),
ast::Expr::F32Const(_) => Some(ast::Type::F32),
ast::Expr::F64Const(_) => Some(ast::Type::F64),
ast::Expr::UnaryOp {
op,
ref mut value,
} => {
ast::Expr::UnaryOp { op, ref mut value } => {
tc_expression(context, value)?;
if value.type_.is_none() {
return expected_type(&value.span, context.source);
@@ -395,7 +395,15 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<()
Some(match (value.type_.unwrap(), op) {
(t, Negate) => t,
(I32 | I64, Not) => I32,
(_, Not) => return type_mismatch(Some(I32), &expr.span, value.type_, &value.span, context.source)
(_, Not) => {
return type_mismatch(
Some(I32),
&expr.span,
value.type_,
&value.span,
context.source,
)
}
})
}
ast::Expr::BinOp {
@@ -557,46 +565,43 @@ fn tc_expression(context: &mut Context, expr: &mut ast::Expression) -> Result<()
} => {
for param in params.iter_mut() {
tc_expression(context, param)?;
if param.type_.is_none() {
return expected_type(&param.span, context.source);
}
}
if let Some((ptypes, rtype)) = context
if let Some(type_map) = context
.functions
.get(name)
.map(|fnc| (fnc.params.as_slice(), fnc.type_))
.or_else(|| builtin_function_types(name))
.map(|fnc| HashMap::from_iter([(fnc.params.clone(), fnc.type_)]))
.or_else(|| context.intrinsics.find_types(name))
{
if params.len() != ptypes.len() {
Report::build(ReportKind::Error, (), expr.span.start)
.with_message(format!(
"Expected {} parameters but found {}",
ptypes.len(),
params.len()
))
.with_label(
Label::new(expr.span.clone())
.with_message(format!(
"Expected {} parameters but found {}",
ptypes.len(),
params.len()
))
.with_color(Color::Red),
)
if let Some(rtype) =
type_map.get(&params.iter().map(|p| p.type_.unwrap()).collect::<Vec<_>>())
{
*rtype
} else {
let mut report = Report::build(ReportKind::Error, (), expr.span.start)
.with_message("No matching function found");
for (params, rtype) in type_map {
let param_str: Vec<_> = params.into_iter().map(|t| t.to_string()).collect();
let msg = format!(
"Found {}({}){}",
name,
param_str.join(", "),
if let Some(rtype) = rtype {
format!(" -> {}", rtype)
} else {
String::new()
}
);
report = report.with_label(Label::new(expr.span.clone()).with_message(msg));
}
report
.finish()
.eprint(Source::from(context.source))
.unwrap();
return Err(());
}
for (ptype, param) in ptypes.iter().zip(params.iter()) {
if param.type_ != Some(*ptype) {
return type_mismatch(
Some(*ptype),
&expr.span,
param.type_,
&param.span,
context.source,
);
}
}
rtype
} else {
Report::build(ReportKind::Error, (), expr.span.start)
.with_message(format!("Unknown function {}", name))
@@ -718,13 +723,15 @@ fn tc_const(expr: &mut ast::Expression, source: &str) -> Result<()> {
use ast::Expr::*;
expr.type_ = Some(match expr.expr {
I32Const(_) => I32,
I64Const(_) => I64,
F32Const(_) => F32,
F64Const(_) => F64,
_ => {
Report::build(ReportKind::Error, (), expr.span.start)
.with_message("Expected I32 constant")
.with_message("Expected constant value")
.with_label(
Label::new(expr.span.clone())
.with_message("Expected I32 constant")
.with_message("Expected constant value")
.with_color(Color::Red),
)
.finish()
@@ -735,15 +742,3 @@ fn tc_const(expr: &mut ast::Expression, source: &str) -> Result<()> {
});
Ok(())
}
fn builtin_function_types(name: &str) -> Option<(&'static [ast::Type], Option<ast::Type>)> {
use ast::Type::*;
let types: (&'static [ast::Type], Option<ast::Type>) = match name {
"sqrt" => (&[F32], Some(F32)),
"abs" => (&[F32], Some(F32)),
"min" => (&[F32, F32], Some(F32)),
"max" => (&[F32, F32], Some(F32)),
_ => return None,
};
Some(types)
}