Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
2 changes: 1 addition & 1 deletion src/Lean/Data/PersistentHashMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.s
else findAtAux keys vals heq (i+1) k
else none

partial def findAux [BEq α] : Node α β → USize → α → Option β
@[specialize] partial def findAux [BEq α] : Node α β → USize → α → Option β
| Node.entries entries, h, k =>
let j := (mod2Shift h shift).toNat
match entries[j]! with
Expand Down
155 changes: 124 additions & 31 deletions src/Lean/Environment.lean
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,54 @@ opaque EnvExtensionStateSpec : (α : Type) × Inhabited α := ⟨Unit, ⟨()⟩
@[expose] def EnvExtensionState : Type := EnvExtensionStateSpec.fst
instance : Inhabited EnvExtensionState := EnvExtensionStateSpec.snd

/--
Sparse copy-on-write container for environment extension states.
Stores an immutable base array (shared across forks) and a small persistent overlay
of modifications. Reads check the overlay first, falling back to the base.
-/
structure ExtensionStates where
/-- Immutable base state, typically set at import time. -/
base : Array EnvExtensionState
/-- Sparse overlay of modified extensions. -/
overlay : Lean.PersistentHashMap Nat EnvExtensionState := {}
deriving Inhabited

namespace ExtensionStates

@[inline] def size (s : ExtensionStates) : Nat := s.base.size

def get (s : ExtensionStates) (i : Nat) (h : i < s.size) : EnvExtensionState :=
s.overlay.find? i |>.getD s.base[i]

def get! (s : ExtensionStates) (i : Nat) : EnvExtensionState :=
s.overlay.find? i |>.getD s.base[i]!

@[inline] def set (s : ExtensionStates) (i : Nat) (v : EnvExtensionState) : ExtensionStates :=
{ s with overlay := s.overlay.insert i v }

@[inline] def modify (s : ExtensionStates) (i : Nat) (f : EnvExtensionState → EnvExtensionState) : ExtensionStates :=
{ s with overlay := s.overlay.insert i (f (s.get! i)) }

@[inline] def push (s : ExtensionStates) (v : EnvExtensionState) : ExtensionStates :=
{ s with base := s.base.push v }

/-- Set directly in the base array, bypassing the overlay. For use during import initialization. -/
@[inline] def setBase (s : ExtensionStates) (i : Nat) (v : EnvExtensionState) : ExtensionStates :=
{ s with base := s.base.set! i v }

/-- Modify directly in the base array, bypassing the overlay. For use during import initialization. -/
@[inline] def modifyBase (s : ExtensionStates) (i : Nat) (f : EnvExtensionState → EnvExtensionState) : ExtensionStates :=
{ s with base := s.base.modify i f }

/-- Create from a plain array (no overlay). -/
@[inline] def ofArray (arr : Array EnvExtensionState) : ExtensionStates :=
{ base := arr }

end ExtensionStates

instance : GetElem ExtensionStates Nat EnvExtensionState fun s i => i < s.size where
getElem s i h := s.get i h

@[expose] def ModuleIdx := Nat
deriving BEq, ToString, Hashable

Expand Down Expand Up @@ -259,12 +307,12 @@ structure Environment where
/--
Environment extensions. It also includes user-defined extensions.
-/
private extensions : Array EnvExtensionState
private extensions : ExtensionStates
/--
Additional imported environment extension state for the interpreter. Access via
`getModuleIREntries`.
-/
private irBaseExts : Array EnvExtensionState
private irBaseExts : ExtensionStates
/-- The header contains additional information that is set at import time. -/
header : EnvironmentHeader := private_decl% {}
deriving Nonempty
Expand Down Expand Up @@ -440,7 +488,7 @@ private structure AsyncConst where
Reported extension state eventually fulfilled by promise; may be missing for tasks (e.g. kernel
checking, synchronous decl addition) that can eagerly guarantee they will not report any state.
-/
exts? : Option (Task (Array EnvExtensionState))
exts? : Option (Task ExtensionStates)
/--
`Task AsyncConsts` except for problematic recursion. The set of nested constants created while
elaborating this constant.
Expand Down Expand Up @@ -571,7 +619,7 @@ structure Environment where
identical to `base.extensions` in other contexts. Access via
`getModuleEntries (level := .server)`.
-/
private serverBaseExts : Array EnvExtensionState := private_decl% base.private.extensions
private serverBaseExts : ExtensionStates := private_decl% base.private.extensions
/--
Kernel environment task that is fulfilled when all asynchronously elaborated declarations are
finished, containing the resulting environment. Also collects the environment extension state of
Expand Down Expand Up @@ -944,7 +992,7 @@ def PromiseCheckedResult.commitChecked (res : PromiseCheckedResult) (env : Envir
private structure ConstPromiseVal where
privateConstInfo : ConstantInfo
exportedConstInfo : ConstantInfo
exts : Array EnvExtensionState
exts : ExtensionStates
nestedConsts : VisibilityMap AsyncConsts
deriving Nonempty

Expand Down Expand Up @@ -1288,6 +1336,10 @@ structure EnvExtension (σ : Type) where private mk ::
present.
-/
replay? : Option (ReplayFn σ)
/-- When `false`, reads and writes bypass the overlay and go directly to the base array.
This is faster for frequently-accessed extensions but means modifications copy the
base array when the environment is shared (RC > 1). Default: `true`. -/
useOverlay : Bool
deriving Inhabited

namespace EnvExtension
Expand All @@ -1301,10 +1353,10 @@ private builtin_initialize envExtensionsRef : IO.Ref (Array (EnvExtension EnvExt
user-defined environment extensions. When this happens, we must adjust the size of the `env.extensions`.
This method is invoked when processing `import`s.
-/
partial def ensureExtensionsArraySize (exts : Array EnvExtensionState) : IO (Array EnvExtensionState) := do
partial def ensureExtensionsArraySize (exts : ExtensionStates) : IO ExtensionStates := do
loop exts.size exts
where
loop (i : Nat) (exts : Array EnvExtensionState) : IO (Array EnvExtensionState) := do
loop (i : Nat) (exts : ExtensionStates) : IO ExtensionStates := do
let envExtensions ← envExtensionsRef.get
if h : i < envExtensions.size then
let s ← envExtensions[i].mkInitial
Expand All @@ -1315,34 +1367,55 @@ where

private def invalidExtMsg := "invalid environment extension has been accessed"

private unsafe def setStateImpl {σ} (ext : EnvExtension σ) (exts : Array EnvExtensionState) (s : σ) : Array EnvExtensionState :=
if h : ext.idx < exts.size then
exts.set ext.idx (unsafeCast s)
private unsafe def setStateImpl {σ} (ext : EnvExtension σ) (exts : ExtensionStates) (s : σ) : ExtensionStates :=
if ext.idx < exts.size then
if ext.useOverlay then exts.set ext.idx (unsafeCast s)
else exts.setBase ext.idx (unsafeCast s)
else
panic! invalidExtMsg

private unsafe def modifyStateImpl {σ : Type} (ext : EnvExtension σ) (exts : ExtensionStates) (f : σ → σ) : ExtensionStates :=
if ext.idx < exts.size then
if ext.useOverlay then
exts.modify ext.idx fun s =>
let s : σ := unsafeCast s
let s : σ := f s
unsafeCast s
else
exts.modifyBase ext.idx fun s =>
let s : σ := unsafeCast s
let s : σ := f s
unsafeCast s
else
panic! invalidExtMsg

/-- Like `setStateImpl` but writes directly to the base array. For use during import initialization. -/
private unsafe def setStateBaseImpl {σ} (ext : EnvExtension σ) (exts : ExtensionStates) (s : σ) : ExtensionStates :=
if ext.idx < exts.size then
exts.setBase ext.idx (unsafeCast s)
else
-- do not return an empty array on panic, avoiding follow-up out-of-bounds accesses
have : Inhabited (Array EnvExtensionState) := ⟨exts⟩
panic! invalidExtMsg

private unsafe def modifyStateImpl {σ : Type} (ext : EnvExtension σ) (exts : Array EnvExtensionState) (f : σ → σ) : Array EnvExtensionState :=
/-- Like `modifyStateImpl` but writes directly to the base array. For use during import initialization. -/
private unsafe def modifyStateBaseImpl {σ : Type} (ext : EnvExtension σ) (exts : ExtensionStates) (f : σ → σ) : ExtensionStates :=
if ext.idx < exts.size then
exts.modify ext.idx fun s =>
exts.modifyBase ext.idx fun s =>
let s : σ := unsafeCast s
let s : σ := f s
unsafeCast s
else
-- do not return an empty array on panic, avoiding follow-up out-of-bounds accesses
have : Inhabited (Array EnvExtensionState) := ⟨exts⟩
panic! invalidExtMsg

private unsafe def getStateImpl {σ} [Inhabited σ] (ext : EnvExtension σ) (exts : Array EnvExtensionState) : σ :=
private unsafe def getStateImpl {σ} [Inhabited σ] (ext : EnvExtension σ) (exts : ExtensionStates) : σ :=
if h : ext.idx < exts.size then
unsafeCast exts[ext.idx]
if ext.useOverlay then unsafeCast exts[ext.idx]
else unsafeCast exts.base[ext.idx]
else
panic! invalidExtMsg

def mkInitialExtStates : IO (Array EnvExtensionState) := do
def mkInitialExtStates : IO ExtensionStates := do
let exts ← envExtensionsRef.get
exts.mapM fun ext => ext.mkInitial
return .ofArray (← exts.mapM fun ext => ext.mkInitial)

/--
Checks whether `modifyState (asyncDecl := declName)` may be called on an async environment
Expand Down Expand Up @@ -1405,6 +1478,15 @@ different environment branches are reconciled.
def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) (asyncMode := ext.asyncMode) : Environment :=
inline <| modifyState (asyncMode := asyncMode) ext env fun _ => s

/--
Like `setState` but writes directly to the base array, bypassing the overlay.
For use during import initialization only (before any forking occurs).
-/
def setStateBase {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
env.modifyCheckedAsync fun env =>
{ env with extensions := unsafe ext.setStateBaseImpl env.extensions s }

-- `unsafe` fails to infer `Nonempty` here
private unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ)
(env : Environment) (asyncMode := ext.asyncMode) (asyncDecl : Name := .anonymous) : σ := Id.run do
Expand Down Expand Up @@ -1482,17 +1564,18 @@ end EnvExtension
For that, you need to register a persistent environment extension. -/
def registerEnvExtension {σ : Type} (mkInitial : IO σ)
(replay? : Option (ReplayFn σ) := none)
(asyncMode : EnvExtension.AsyncMode := .mainOnly) : IO (EnvExtension σ) := do
(asyncMode : EnvExtension.AsyncMode := .mainOnly)
(useOverlay : Bool := true) : IO (EnvExtension σ) := do
unless (← initializing) do
throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
let exts ← EnvExtension.envExtensionsRef.get
let idx := exts.size
let ext : EnvExtension σ := { idx, mkInitial, asyncMode, replay? }
let ext : EnvExtension σ := { idx, mkInitial, asyncMode, replay?, useOverlay }
-- safety: `EnvExtensionState` is opaque, so we can upcast to it
EnvExtension.envExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext)
pure ext

private def mkInitialExtensionStates : IO (Array EnvExtensionState) := EnvExtension.mkInitialExtStates
private def mkInitialExtensionStates : IO ExtensionStates := EnvExtension.mkInitialExtStates

@[export lean_mk_empty_environment]
def mkEmptyEnvironment (trustLevel : UInt32 := 0) : IO Environment := do
Expand Down Expand Up @@ -1656,6 +1739,10 @@ end PersistentEnvExtension

builtin_initialize persistentEnvExtensionsRef : IO.Ref (Array (PersistentEnvExtension EnvExtensionEntry EnvExtensionEntry EnvExtensionState)) ← IO.mkRef #[]

/-- Hook called after `finalizePersistentExtensions` during `importModules`.
Used to initialize centralized scope stack states. -/
builtin_initialize postFinalizePersistentExtensionsHookRef : IO.Ref (Environment → IO Environment) ← IO.mkRef pure

-- Helper structure to enable cyclic default values of `exportEntriesFn` and `exportEntriesFnEx`.
structure PersistentEnvExtensionDescrCore (α β σ : Type) where
name : Name := by exact decl_name%
Expand All @@ -1666,6 +1753,7 @@ structure PersistentEnvExtensionDescrCore (α β σ : Type) where
statsFn : σ → Format := fun _ => Format.nil
asyncMode : EnvExtension.AsyncMode := .mainOnly
replay? : Option (ReplayFn σ) := none
useOverlay : Bool := true

attribute [inherit_doc PersistentEnvExtension.exportEntriesFn]
PersistentEnvExtensionDescrCore.exportEntriesFnEx
Expand All @@ -1692,7 +1780,7 @@ unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ]
if pExts.any (fun ext => ext.name == descr.name) then throw (IO.userError s!"invalid environment extension, '{descr.name}' has already been used")
let replay? := descr.replay?.map fun replay =>
fun oldState newState newConsts s => { s with state := replay oldState.state newState.state newConsts s.state }
let ext ← registerEnvExtension (asyncMode := descr.asyncMode) (replay? := replay?) do
let ext ← registerEnvExtension (asyncMode := descr.asyncMode) (replay? := replay?) (useOverlay := descr.useOverlay) do
let initial ← descr.mkInitial
let s : PersistentEnvExtensionState α σ := {
importedEntries := #[],
Expand Down Expand Up @@ -1845,23 +1933,23 @@ def mkExtNameMap (startingAt : Nat) : IO (Std.HashMap Name Nat) := do
result := result.insert descr.name i
return result

private def setImportedEntries (states : Array EnvExtensionState) (mods : Array ModuleData)
(startingAt : Nat := 0) : IO (Array EnvExtensionState) := do
private def setImportedEntries (states : ExtensionStates) (mods : Array ModuleData)
(startingAt : Nat := 0) : IO ExtensionStates := do
let mut states := states
let extDescrs ← persistentEnvExtensionsRef.get
/- For extensions starting at `startingAt`, ensure their `importedEntries` array have size `mods.size`. -/
for extDescr in extDescrs[startingAt...*] do
-- safety: as in `modifyState`
states := unsafe extDescr.toEnvExtension.modifyStateImpl states fun s =>
-- safety: as in `modifyState`; write to base since this is import-time initialization
states := unsafe extDescr.toEnvExtension.modifyStateBaseImpl states fun s =>
{ s with importedEntries := .replicate mods.size #[] }
/- For each module `mod`, and `mod.entries`, if the extension name is one of the extensions after `startingAt`, set `entries` -/
let extNameIdx ← mkExtNameMap startingAt
for h : modIdx in *...mods.size do
let mod := mods[modIdx]
for (extName, entries) in mod.entries do
if let some entryIdx := extNameIdx[extName]? then
-- safety: as in `modifyState`
states := unsafe extDescrs[entryIdx]!.toEnvExtension.modifyStateImpl states fun s =>
-- safety: as in `modifyState`; write to base since this is import-time initialization
states := unsafe extDescrs[entryIdx]!.toEnvExtension.modifyStateBaseImpl states fun s =>
{ s with importedEntries := s.importedEntries.set! modIdx entries }
return states

Expand Down Expand Up @@ -1903,7 +1991,7 @@ where
let prevSize := (← persistentEnvExtensionsRef.get).size
let prevAttrSize ← getNumBuiltinAttributes
let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts }
let mut env := extDescr.toEnvExtension.setState (asyncMode := .sync) env { s with state := newState }
let mut env := extDescr.toEnvExtension.setStateBase env { s with state := newState }
if extDescr.name == `Lean.regularInitAttr then
-- Run `[init]` attributes now. We do this after `setState` so `runInitAttrs` can access
-- `getModule(IR)Entries` but we should also do it before attempting to run user-defined
Expand Down Expand Up @@ -2290,6 +2378,9 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
env ← unsafe Runtime.markPersistent env
if loadExts then
env ← finalizePersistentExtensions env moduleData opts
-- Run post-finalization hooks (e.g., initialize centralized scope stack)
let hook ← postFinalizePersistentExtensionsHookRef.get
env ← hook env
if leakEnv then
/- Ensure the final environment including environment extension states is
marked persistent as documented.
Expand Down Expand Up @@ -2392,6 +2483,8 @@ def displayStats (env : Environment) : IO Unit := do
IO.println ("number of buckets for imported consts: " ++ toString env.constants.numBuckets);
IO.println ("trust level: " ++ toString env.header.trustLevel);
IO.println ("number of extensions: " ++ toString env.base.private.extensions.size);
let overlayCount := env.base.private.extensions.overlay.foldl (fun n _ _ => n + 1) 0
IO.println ("extensions in overlay: " ++ toString overlayCount);
pExtDescrs.forM fun extDescr => do
IO.println ("extension '" ++ toString extDescr.name ++ "'")
-- get state from `checked` at the end if `async`; it would otherwise panic
Expand Down
1 change: 1 addition & 0 deletions src/Lean/ReducibilityAttrs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ builtin_initialize reducibilityCoreExt : PersistentEnvExtension (Name × Reducib
statsFn := fun s => "reducibility attribute core extension" ++ Format.line ++ "number of local entries: " ++ format s.size
-- attribute is set by `addPreDefinitions`
asyncMode := .async .asyncEnv
useOverlay := false
replay? := some <| fun _oldState newState newItems otherState =>
newItems.foldl (init := otherState) fun otherState k =>
if let some v := newState.find? k then
Expand Down
Loading