Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,17 +402,13 @@ impl<'tcx> Analyzer<'tcx> {
}
}

/// Computes the signature of the local function.
/// Computes the signature of the function using the given `body`.
///
/// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
/// This works like `self.tcx.fn_sig(def_id).instantiate_identity().skip_binder()`,
/// but extracts parameter and return types directly from the given `body` to obtain a signature that
/// reflects potential type instantiations happened after `optimized_mir`.
pub fn local_fn_sig_with_body(
&self,
local_def_id: LocalDefId,
body: &mir::Body<'tcx>,
) -> mir_ty::FnSig<'tcx> {
let ty = self.tcx.type_of(local_def_id).instantiate_identity();
pub fn fn_sig_with_body(&self, def_id: DefId, body: &mir::Body<'tcx>) -> mir_ty::FnSig<'tcx> {
let ty = self.tcx.type_of(def_id).instantiate_identity();
let sig = if let mir_ty::TyKind::Closure(_, substs) = ty.kind() {
substs.as_closure().sig().skip_binder()
} else {
Expand All @@ -428,14 +424,14 @@ impl<'tcx> Analyzer<'tcx> {
)
}

/// Computes the signature of the local function.
/// Computes the signature of the function.
///
/// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
/// This works like `self.tcx.fn_sig(def_id).instantiate_identity().skip_binder()`,
/// but extracts parameter and return types directly from [`mir::Body`] to obtain a signature that
/// reflects the actual type of lifted closure functions.
pub fn local_fn_sig(&self, local_def_id: LocalDefId) -> mir_ty::FnSig<'tcx> {
let body = self.tcx.optimized_mir(local_def_id);
self.local_fn_sig_with_body(local_def_id, body)
pub fn fn_sig(&self, def_id: DefId) -> mir_ty::FnSig<'tcx> {
let body = self.tcx.optimized_mir(def_id);
self.fn_sig_with_body(def_id, body)
}

fn extract_require_annot<T>(
Expand Down Expand Up @@ -487,4 +483,12 @@ impl<'tcx> Analyzer<'tcx> {
}
ensure_annot
}

/// Whether the given `def_id` corresponds to a method of one of the `Fn` traits.
fn is_fn_trait_method(&self, def_id: DefId) -> bool {
self.tcx
.trait_of_item(def_id)
.and_then(|trait_did| self.tcx.fn_trait_kind_from_def_id(trait_did))
.is_some()
}
}
40 changes: 35 additions & 5 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub struct Analyzer<'tcx, 'ctx> {
}

impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
fn ctx(&self) -> &analyze::Analyzer<'tcx> {
&*self.ctx
}

fn is_defined(&self, local: Local) -> bool {
self.env.contains_local(local)
}
Expand All @@ -53,6 +57,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
visitor::ReborrowVisitor::new(self)
}

fn rust_call_visitor<'a>(&'a mut self) -> visitor::RustCallVisitor<'a, 'tcx, 'ctx> {
visitor::RustCallVisitor::new(self)
}

fn basic_block_ty(&self, bb: BasicBlock) -> &BasicBlockType {
self.ctx.basic_block_ty(self.local_def_id, bb)
}
Expand Down Expand Up @@ -568,12 +576,28 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
{
// TODO: handle const_fn_def on Env side
let func_ty = if let Some((def_id, args)) = func.const_fn_def() {
let param_env = self.tcx.param_env(self.local_def_id);
let instance = mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
let resolved_def_id = if let Some(instance) = instance {
instance.def_id()
let resolved_def_id = if self.ctx.is_fn_trait_method(def_id) {
// When calling a closure via `Fn`/`FnMut`/`FnOnce` trait,
// we simply replace the def_id with the closure's function def_id.
// This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor
// adjusts the arguments accordingly.
let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else {
panic!("expected closure arg for fn trait");
};
tracing::debug!(?closure_def_id, "closure instance");
*closure_def_id
} else {
def_id
let param_env = self
.tcx
.param_env(self.local_def_id)
.with_reveal_all_normalized(self.tcx);
let instance =
mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
if let Some(instance) = instance {
instance.def_id()
} else {
def_id
}
};
if def_id != resolved_def_id {
tracing::info!(?def_id, ?resolved_def_id, "resolve",);
Expand Down Expand Up @@ -671,6 +695,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.env.borrow_place(place, prophecy).into()
}

fn immut_borrow_place(&self, referent: mir::Place<'tcx>) -> rty::RefinedType<Var> {
let place = self.elaborate_place(&referent);
self.env.place_type(place).immut().into()
}

#[tracing::instrument(skip(self, lhs, rvalue))]
fn analyze_assignment(
&mut self,
Expand Down Expand Up @@ -754,6 +783,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
source_info: term.source_info,
};
}
self.rust_call_visitor().visit_terminator(&mut term);
self.reborrow_visitor().visit_terminator(&mut term);
tracing::debug!(term = ?term.kind);
term
Expand Down
139 changes: 4 additions & 135 deletions src/analyze/basic_block/visitor.rs
Original file line number Diff line number Diff line change
@@ -1,136 +1,5 @@
use rustc_middle::mir::{self, Local};
use rustc_middle::ty::{self as mir_ty, TyCtxt};
mod reborrow;
mod rust_call;

use crate::analyze::ReplacePlacesVisitor;

pub struct ReborrowVisitor<'a, 'tcx, 'ctx> {
tcx: TyCtxt<'tcx>,
analyzer: &'a mut super::Analyzer<'tcx, 'ctx>,
}

impl<'tcx> ReborrowVisitor<'_, 'tcx, '_> {
fn insert_borrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
let new_local = self.analyzer.local_decls.push(decl);
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
self.analyzer.bind_local(new_local, new_local_ty);
tracing::info!(old_place = ?place, ?new_local, "implicitly borrowed");
new_local
}

fn insert_reborrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
let new_local = self.analyzer.local_decls.push(decl);
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
self.analyzer.bind_local(new_local, new_local_ty);
tracing::info!(old_place = ?place, ?new_local, "implicitly reborrowed");
new_local
}
}

impl<'a, 'tcx, 'ctx> mir::visit::MutVisitor<'tcx> for ReborrowVisitor<'a, 'tcx, 'ctx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn visit_assign(
&mut self,
place: &mut mir::Place<'tcx>,
rvalue: &mut mir::Rvalue<'tcx>,
location: mir::Location,
) {
if !self.analyzer.is_defined(place.local) {
self.super_assign(place, rvalue, location);
return;
}

if place.projection.is_empty() && self.analyzer.is_mut_local(place.local) {
let ty = self.analyzer.local_decls[place.local].ty;
let new_local = self.insert_borrow(place.local.into(), ty);
let new_place = self.tcx.mk_place_deref(new_local.into());
ReplacePlacesVisitor::with_replacement(self.tcx, place.local.into(), new_place)
.visit_rvalue(rvalue, location);
*place = new_place;
self.super_assign(place, rvalue, location);
return;
}

let inner_place = if place.projection.last() == Some(&mir::PlaceElem::Deref) {
// *m = *m + 1 => m1 = &mut m; *m1 = *m + 1
let mut projection = place.projection.as_ref().to_vec();
projection.pop();
mir::Place {
local: place.local,
projection: self.tcx.mk_place_elems(&projection),
}
} else {
// s.0 = s.0 + 1 => m1 = &mut s.0; *m1 = *m1 + 1
*place
};

let ty = inner_place.ty(&self.analyzer.local_decls, self.tcx).ty;
let (new_local, new_place) = match ty.kind() {
mir_ty::TyKind::Ref(_, inner_ty, m) if m.is_mut() => {
let new_local = self.insert_reborrow(*place, *inner_ty);
(new_local, new_local.into())
}
mir_ty::TyKind::Adt(adt, args) if adt.is_box() => {
let inner_ty = args.type_at(0);
let new_local = self.insert_borrow(*place, inner_ty);
(new_local, new_local.into())
}
_ => {
let new_local = self.insert_borrow(*place, ty);
(new_local, self.tcx.mk_place_deref(new_local.into()))
}
};

ReplacePlacesVisitor::with_replacement(self.tcx, inner_place, new_place)
.visit_rvalue(rvalue, location);
*place = self.tcx.mk_place_deref(new_local.into());
self.super_assign(place, rvalue, location);
}

// TODO: is it always true that the operand is not referred again in rvalue
fn visit_operand(&mut self, operand: &mut mir::Operand<'tcx>, location: mir::Location) {
let Some(p) = operand.place() else {
self.super_operand(operand, location);
return;
};

let mir_ty::TyKind::Ref(_, inner_ty, m) =
p.ty(&self.analyzer.local_decls, self.tcx).ty.kind()
else {
self.super_operand(operand, location);
return;
};

if m.is_mut() {
let new_local = self.insert_reborrow(self.tcx.mk_place_deref(p), *inner_ty);
*operand = mir::Operand::Move(new_local.into());
}

self.super_operand(operand, location);
}
}

impl<'a, 'tcx, 'ctx> ReborrowVisitor<'a, 'tcx, 'ctx> {
pub fn new(analyzer: &'a mut super::Analyzer<'tcx, 'ctx>) -> Self {
let tcx = analyzer.tcx;
Self { analyzer, tcx }
}

pub fn visit_statement(&mut self, stmt: &mut mir::Statement<'tcx>) {
// dummy location
mir::visit::MutVisitor::visit_statement(self, stmt, mir::Location::START);
}

pub fn visit_terminator(&mut self, term: &mut mir::Terminator<'tcx>) {
// dummy location
mir::visit::MutVisitor::visit_terminator(self, term, mir::Location::START);
}
}
pub use reborrow::ReborrowVisitor;
pub use rust_call::RustCallVisitor;
Loading