Browse Source

cranelift-isle: trie construction and IR cleanups (#5171)

One big change here is to stop using `Term::extractor_sig`, which was
the only call that used a `TypeEnv`. However that function only uses
type information to construct the fully-qualified name of the extractor,
which is not used when building the IR. So removing it and removing the
now-unused `typeenv` parameters removes all uses of `TypeEnv` from the
`ir` and `trie` modules.

In addition, this completes the changes started in "More consistent use
of `add_inst`" (e63771f2d9), by always
using `add_inst` to get an `InstId`.

I also removed a number of unnecessary intermediate allocations.
pull/5176/head
Jamey Sharp 2 years ago
committed by GitHub
parent
commit
033758daaf
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      cranelift/isle/isle/src/compile.rs
  2. 140
      cranelift/isle/isle/src/ir.rs
  3. 98
      cranelift/isle/isle/src/trie.rs

2
cranelift/isle/isle/src/compile.rs

@ -8,6 +8,6 @@ pub fn compile(defs: &ast::Defs, options: &codegen::CodegenOptions) -> Result<St
let mut typeenv = sema::TypeEnv::from_ast(defs)?;
let termenv = sema::TermEnv::from_ast(&mut typeenv, defs)?;
crate::overlap::check(&mut typeenv, &termenv)?;
let tries = trie::build_tries(&typeenv, &termenv);
let tries = trie::build_tries(&termenv);
Ok(codegen::codegen(&typeenv, &termenv, &tries, options))
}

140
cranelift/isle/isle/src/ir.rs

@ -285,20 +285,17 @@ impl PatternSequence {
arg_tys: &[TypeId],
variant: VariantId,
) -> Vec<Value> {
let inst = InstId(self.insts.len());
let mut outs = vec![];
for (i, _arg_ty) in arg_tys.iter().enumerate() {
let val = Value::Pattern { inst, output: i };
outs.push(val);
}
let arg_tys = arg_tys.iter().cloned().collect();
self.add_inst(PatternInst::MatchVariant {
let outputs = arg_tys.len();
let arg_tys = arg_tys.into();
let inst = self.add_inst(PatternInst::MatchVariant {
input,
input_ty,
arg_tys,
variant,
});
outs
(0..outputs)
.map(|output| Value::Pattern { inst, output })
.collect()
}
fn add_extract(
@ -310,14 +307,8 @@ impl PatternSequence {
infallible: bool,
multi: bool,
) -> Vec<Value> {
let inst = InstId(self.insts.len());
let mut outs = vec![];
for i in 0..output_tys.len() {
let val = Value::Pattern { inst, output: i };
outs.push(val);
}
let output_tys = output_tys.iter().cloned().collect();
self.add_inst(PatternInst::Extract {
let outputs = output_tys.len();
let inst = self.add_inst(PatternInst::Extract {
inputs,
input_tys,
output_tys,
@ -325,7 +316,9 @@ impl PatternSequence {
infallible,
multi,
});
outs
(0..outputs)
.map(|output| Value::Pattern { inst, output })
.collect()
}
fn add_expr_seq(&mut self, seq: ExprSequence, output: Value, output_ty: TypeId) -> Value {
@ -344,7 +337,6 @@ impl PatternSequence {
fn gen_pattern(
&mut self,
input: ValueOrArgs,
typeenv: &TypeEnv,
termenv: &TermEnv,
pat: &Pattern,
vars: &mut StableMap<VarId, Value>,
@ -356,8 +348,7 @@ impl PatternSequence {
if let Some(v) = input.to_value() {
vars.insert(var, v);
}
let root_term = self.gen_pattern(input, typeenv, termenv, &*subpat, vars);
root_term
self.gen_pattern(input, termenv, subpat, vars);
}
&Pattern::Var(ty, var) => {
// Assert that the value matches the existing bound var.
@ -394,32 +385,15 @@ impl PatternSequence {
let arg_tys = &termdata.arg_tys[..];
for (i, subpat) in args.iter().enumerate() {
let value = self.add_arg(i, arg_tys[i]);
self.gen_pattern(
ValueOrArgs::Value(value),
typeenv,
termenv,
subpat,
vars,
);
self.gen_pattern(ValueOrArgs::Value(value), termenv, subpat, vars);
}
}
ValueOrArgs::Value(input) => {
// Determine whether the term has an external extractor or not.
let termdata = &termenv.terms[term.index()];
let arg_tys = &termdata.arg_tys[..];
match &termdata.kind {
let arg_values = match &termdata.kind {
TermKind::EnumVariant { variant } => {
let arg_values =
self.add_match_variant(input, ty, arg_tys, *variant);
for (subpat, value) in args.iter().zip(arg_values.into_iter()) {
self.gen_pattern(
ValueOrArgs::Value(value),
typeenv,
termenv,
subpat,
vars,
);
}
self.add_match_variant(input, ty, &termdata.arg_tys, *variant)
}
TermKind::Decl {
extractor_kind: None,
@ -434,50 +408,36 @@ impl PatternSequence {
panic!("Should have been expanded away")
}
TermKind::Decl {
extractor_kind: Some(ExtractorKind::ExternalExtractor { .. }),
multi,
extractor_kind:
Some(ExtractorKind::ExternalExtractor { infallible, .. }),
..
} => {
let ext_sig = termdata.extractor_sig(typeenv).unwrap();
// Evaluate all `input` args.
let mut inputs = vec![];
let mut input_tys = vec![];
let mut output_tys = vec![];
let mut output_pats = vec![];
inputs.push(input);
input_tys.push(termdata.ret_ty);
for arg in args {
output_tys.push(arg.ty());
output_pats.push(arg);
}
let inputs = vec![input];
let input_tys = vec![termdata.ret_ty];
let output_tys = args.iter().map(|arg| arg.ty()).collect();
// Invoke the extractor.
let arg_values = self.add_extract(
self.add_extract(
inputs,
input_tys,
output_tys,
term,
ext_sig.infallible,
ext_sig.multi,
);
for (pat, &val) in output_pats.iter().zip(arg_values.iter()) {
self.gen_pattern(
ValueOrArgs::Value(val),
typeenv,
termenv,
pat,
vars,
);
}
*infallible && !*multi,
*multi,
)
}
};
for (pat, val) in args.iter().zip(arg_values) {
self.gen_pattern(ValueOrArgs::Value(val), termenv, pat, vars);
}
}
}
}
&Pattern::And(_ty, ref children) => {
for child in children {
self.gen_pattern(input, typeenv, termenv, child, vars);
self.gen_pattern(input, termenv, child, vars);
}
}
&Pattern::Wildcard(_ty) => {
@ -506,11 +466,10 @@ impl ExprSequence {
fn add_create_variant(
&mut self,
inputs: &[(Value, TypeId)],
inputs: Vec<(Value, TypeId)>,
ty: TypeId,
variant: VariantId,
) -> Value {
let inputs = inputs.iter().cloned().collect();
let inst = self.add_inst(ExprInst::CreateVariant {
inputs,
ty,
@ -521,13 +480,12 @@ impl ExprSequence {
fn add_construct(
&mut self,
inputs: &[(Value, TypeId)],
inputs: Vec<(Value, TypeId)>,
ty: TypeId,
term: TermId,
infallible: bool,
multi: bool,
) -> Value {
let inputs = inputs.iter().cloned().collect();
let inst = self.add_inst(ExprInst::Construct {
inputs,
ty,
@ -551,7 +509,6 @@ impl ExprSequence {
/// term ID, if any.
fn gen_expr(
&mut self,
typeenv: &TypeEnv,
termenv: &TermEnv,
expr: &Expr,
vars: &StableMap<VarId, Value>,
@ -567,22 +524,22 @@ impl ExprSequence {
} => {
let mut vars = vars.clone();
for &(var, _var_ty, ref var_expr) in bindings {
let var_value = self.gen_expr(typeenv, termenv, &*var_expr, &vars);
let var_value = self.gen_expr(termenv, var_expr, &vars);
vars.insert(var, var_value);
}
self.gen_expr(typeenv, termenv, body, &vars)
self.gen_expr(termenv, body, &vars)
}
&Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(),
&Expr::Term(ty, term, ref arg_exprs) => {
let termdata = &termenv.terms[term.index()];
let mut arg_values_tys = vec![];
for (arg_ty, arg_expr) in termdata.arg_tys.iter().cloned().zip(arg_exprs.iter()) {
arg_values_tys
.push((self.gen_expr(typeenv, termenv, &*arg_expr, &vars), arg_ty));
}
let arg_values_tys = arg_exprs
.iter()
.map(|arg_expr| self.gen_expr(termenv, arg_expr, vars))
.zip(termdata.arg_tys.iter().copied())
.collect();
match &termdata.kind {
TermKind::EnumVariant { variant } => {
self.add_create_variant(&arg_values_tys[..], ty, *variant)
self.add_create_variant(arg_values_tys, ty, *variant)
}
TermKind::Decl {
constructor_kind: Some(ConstructorKind::InternalConstructor),
@ -590,7 +547,7 @@ impl ExprSequence {
..
} => {
self.add_construct(
&arg_values_tys[..],
arg_values_tys,
ty,
term,
/* infallible = */ false,
@ -604,7 +561,7 @@ impl ExprSequence {
..
} => {
self.add_construct(
&arg_values_tys[..],
arg_values_tys,
ty,
term,
/* infallible = */ !pure,
@ -622,16 +579,13 @@ impl ExprSequence {
}
/// Build a sequence from a rule.
pub fn lower_rule(
tyenv: &TypeEnv,
termenv: &TermEnv,
rule: RuleId,
) -> (PatternSequence, ExprSequence) {
pub fn lower_rule(termenv: &TermEnv, rule: RuleId) -> (PatternSequence, ExprSequence) {
let mut pattern_seq: PatternSequence = Default::default();
let mut expr_seq: ExprSequence = Default::default();
expr_seq.pos = termenv.rules[rule.index()].pos;
let ruledata = &termenv.rules[rule.index()];
expr_seq.pos = ruledata.pos;
let mut vars = StableMap::new();
let root_term = ruledata
.lhs
@ -643,7 +597,6 @@ pub fn lower_rule(
// Lower the pattern, starting from the root input value.
pattern_seq.gen_pattern(
ValueOrArgs::ImplicitTermFromArgs(root_term),
tyenv,
termenv,
&ruledata.lhs,
&mut vars,
@ -653,13 +606,12 @@ pub fn lower_rule(
// `PatternInst::Expr` for the sub-exprs (right-hand sides).
for iflet in &ruledata.iflets {
let mut subexpr_seq: ExprSequence = Default::default();
let subexpr_ret_value = subexpr_seq.gen_expr(tyenv, termenv, &iflet.rhs, &mut vars);
let subexpr_ret_value = subexpr_seq.gen_expr(termenv, &iflet.rhs, &mut vars);
subexpr_seq.add_return(iflet.rhs.ty(), subexpr_ret_value);
let pattern_value =
pattern_seq.add_expr_seq(subexpr_seq, subexpr_ret_value, iflet.rhs.ty());
pattern_seq.gen_pattern(
ValueOrArgs::Value(pattern_value),
tyenv,
termenv,
&iflet.lhs,
&mut vars,
@ -668,7 +620,7 @@ pub fn lower_rule(
// Lower the expression, making use of the bound variables
// from the pattern.
let rhs_root_val = expr_seq.gen_expr(tyenv, termenv, &ruledata.rhs, &vars);
let rhs_root_val = expr_seq.gen_expr(termenv, &ruledata.rhs, &vars);
// Return the root RHS value.
let output_ty = ruledata.rhs.ty();
expr_seq.add_return(output_ty, rhs_root_val);

98
cranelift/isle/isle/src/trie.rs

@ -1,14 +1,14 @@
//! Trie construction.
use crate::ir::{lower_rule, ExprSequence, PatternInst, PatternSequence};
use crate::ir::{lower_rule, ExprSequence, PatternInst};
use crate::log;
use crate::sema::{RuleId, TermEnv, TermId, TypeEnv};
use crate::sema::{TermEnv, TermId};
use std::collections::BTreeMap;
/// Construct the tries for each term.
pub fn build_tries(typeenv: &TypeEnv, termenv: &TermEnv) -> BTreeMap<TermId, TrieNode> {
let mut builder = TermFunctionsBuilder::new(typeenv, termenv);
builder.build();
pub fn build_tries(termenv: &TermEnv) -> BTreeMap<TermId, TrieNode> {
let mut builder = TermFunctionsBuilder::default();
builder.build(termenv);
log!("builder: {:?}", builder);
builder.finalize()
}
@ -280,91 +280,43 @@ impl TrieNode {
}
}
/// Builder context for one function in generated code corresponding
/// to one root input term.
///
/// A `TermFunctionBuilder` can correspond to the matching
/// control-flow and operations that we execute either when evaluating
/// *forward* on a term, trying to match left-hand sides against it
/// and transforming it into another term; or *backward* on a term,
/// trying to match another rule's left-hand side against an input to
/// produce the term in question (when the term is used in the LHS of
/// the calling term).
#[derive(Debug)]
struct TermFunctionBuilder {
trie: TrieNode,
}
impl TermFunctionBuilder {
fn new() -> Self {
TermFunctionBuilder {
trie: TrieNode::Empty,
}
}
fn add_rule(&mut self, prio: Prio, pattern_seq: PatternSequence, expr_seq: ExprSequence) {
let symbols = pattern_seq
.insts
.into_iter()
.map(|op| TrieSymbol::Match { op })
.chain(std::iter::once(TrieSymbol::EndOfMatch));
self.trie.insert(prio, symbols, expr_seq);
}
fn sort_trie(&mut self) {
self.trie.sort();
}
#[derive(Debug, Default)]
struct TermFunctionsBuilder {
builders_by_term: BTreeMap<TermId, TrieNode>,
}
#[derive(Debug)]
struct TermFunctionsBuilder<'a> {
typeenv: &'a TypeEnv,
termenv: &'a TermEnv,
builders_by_term: BTreeMap<TermId, TermFunctionBuilder>,
}
impl<'a> TermFunctionsBuilder<'a> {
fn new(typeenv: &'a TypeEnv, termenv: &'a TermEnv) -> Self {
log!("typeenv: {:?}", typeenv);
impl TermFunctionsBuilder {
fn build(&mut self, termenv: &TermEnv) {
log!("termenv: {:?}", termenv);
Self {
builders_by_term: BTreeMap::new(),
typeenv,
termenv,
}
}
fn build(&mut self) {
for rule in 0..self.termenv.rules.len() {
let rule = RuleId(rule);
let prio = self.termenv.rules[rule.index()].prio;
let (pattern, expr) = lower_rule(self.typeenv, self.termenv, rule);
let root_term = self.termenv.rules[rule.index()].lhs.root_term().unwrap();
for rule in termenv.rules.iter() {
let (pattern, expr) = lower_rule(termenv, rule.id);
let root_term = rule.lhs.root_term().unwrap();
log!(
"build:\n- rule {:?}\n- pattern {:?}\n- expr {:?}",
self.termenv.rules[rule.index()],
rule,
pattern,
expr
);
let symbols = pattern
.insts
.into_iter()
.map(|op| TrieSymbol::Match { op })
.chain(std::iter::once(TrieSymbol::EndOfMatch));
self.builders_by_term
.entry(root_term)
.or_insert_with(|| TermFunctionBuilder::new())
.add_rule(prio, pattern.clone(), expr.clone());
.or_insert(TrieNode::Empty)
.insert(rule.prio, symbols, expr);
}
for builder in self.builders_by_term.values_mut() {
builder.sort_trie();
builder.sort();
}
}
fn finalize(self) -> BTreeMap<TermId, TrieNode> {
let functions_by_term = self
.builders_by_term
.into_iter()
.map(|(term, builder)| (term, builder.trie))
.collect::<BTreeMap<_, _>>();
functions_by_term
self.builders_by_term
}
}

Loading…
Cancel
Save