Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 113 additions & 92 deletions ITree/ITree.lean
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Max Nowak
-/

import Qpf
import Mathlib.Data.QPF.Multivariate.Constructions.Sigma

Expand All @@ -27,12 +21,12 @@ Thank you to Alex Keizer (github.com/alexkeizer) for the help in figuring this o
-/

-- Intuitively: `ρ` is `ITree E A`, and `ν` is `(Ans : Type) × (E Ans) × (Ans → ITree E A)`.
inductive Shape (E : Type -> Type) (A : Type 1) (ρ : Type 1) (ν : Type 1) : Type 1
inductive Shape (E : Type u -> Type u) (A : Type (u+1)) (ρ : Type (u+1)) (ν : Type (u+1)) : Type (u+1)
| ret (r : A) : Shape E A ρ ν
| tau (t : ρ) : Shape E A ρ ν
| vis (e : ν) : Shape E A ρ ν

abbrev Shape.Uncurried (E : Type -> Type) : TypeFun 3 := TypeFun.ofCurried (Shape E)
abbrev Shape.Uncurried (E : Type u -> Type u) : TypeFun 3 := TypeFun.ofCurried (Shape E)

instance : MvFunctor (Shape.Uncurried E) where
map f
Expand All @@ -44,57 +38,58 @@ instance : MvFunctor (Shape.Uncurried E) where
## Step 2: Functors for constructors
-/

qpf G_ret (E : Type -> Type) A ρ := A
qpf G_tau (E : Type -> Type) A ρ := ρ
qpf G_ret (E : Type u -> Type u) A ρ := A
qpf G_tau (E : Type u -> Type u) A ρ := ρ
-- qpf G_vis (E : Type -> Type) A ρ := (Ans : Type) × E Ans × (Ans → ρ) -- this unfortunately doesn't work, hence the workaround below

section SigmaWorkaround
/-- `qpf G (Ans : Type) (E : Type → Type) A ρ ν := E Ans × (Ans → ρ)`, but universe-polymorphic -/
abbrev G.Uncurried (Ans : Type 0) (E : Type 0 → Type 0) : TypeFun.{1, 1} 2 :=
abbrev G.Uncurried (Ans : Type u) (E : Type u → Type u) : TypeFun.{u+1, u+1} 2 :=
MvQPF.Comp (n := 2) (m := 2) -- compose two 2-ary (A, ρ) functors `E Ans` and `Ans -> ρ`
(TypeFun.ofCurried Prod.{1, 1}) -- ✓ MvQPF
(TypeFun.ofCurried Prod.{u+1, u+1}) -- ✓ MvQPF
![
MvQPF.Const 2 (ULift (E Ans)), -- ✓ inst₁
MvQPF.Const 2 (ULift.{u+1} (E Ans)), -- ✓ inst₁
MvQPF.Comp (n := 1) (m := 2) -- ✓ inst₂
(TypeFun.ofCurried (MvQPF.Arrow (ULift.{1} Ans)))
(TypeFun.ofCurried (MvQPF.Arrow (ULift.{u+1} Ans)))
![MvQPF.Prj (n := 2) 0]
]

instance inst₁ {Ans : Type} {E : Type -> Type} : MvQPF.{1, 1} (MvQPF.Const 2 (ULift (E Ans))) := inferInstance
instance inst₁ {Ans : Type u} {E : Type u -> Type u} : MvQPF.{u+1, u+1} (MvQPF.Const 2 (ULift.{u+1} (E Ans))) := inferInstance

instance : MvQPF.{1,1} (MvQPF.Prj (n := 2) 0) := inferInstance
instance : MvQPF.{1,1} (![MvQPF.Prj (n := 2) 0] 0) := MvQPF.Prj.mvqpf 0
instance : ∀i, MvQPF.{1,1} (![MvQPF.Prj (n := 2) 0] i) := fun | 0 => inferInstance
instance inst₂ {Ans : Type} : MvQPF.{1, 1} (MvQPF.Comp (n := 1) (m := 2)
(TypeFun.ofCurried (MvQPF.Arrow (ULift.{1} Ans)))
instance : MvQPF.{u+1,u+1} (MvQPF.Prj (n := 2) 0) := inferInstance
instance : MvQPF.{u+1,u+1} (![MvQPF.Prj (n := 2) 0] 0) := MvQPF.Prj.mvqpf 0
instance : ∀i, MvQPF.{u+1,u+1} (![MvQPF.Prj (n := 2) 0] i) := fun | 0 => inferInstance
instance inst₂ {Ans : Type u} : MvQPF.{u+1, u+1} (MvQPF.Comp (n := 1) (m := 2)
(TypeFun.ofCurried (MvQPF.Arrow (ULift.{u+1} Ans)))
![MvQPF.Prj (n := 2) 0]
) := inferInstance

abbrev G_vis.Uncurried (E : Type → Type) : TypeFun.{1,1} 2 :=
MvQPF.Sigma.{1} fun (Ans : Type) => (G.Uncurried Ans E)
set_option pp.universes true
abbrev G_vis.Uncurried (E : Type u → Type u) : TypeFun.{u+1,u+1} 2 :=
MvQPF.Sigma.{_} (A := Type (u)) fun (Ans : Type u) => (G.Uncurried Ans E)

def G_vis (E : Type → Type) (A : Type) (ρ : Type 1) : Type 1 := G_vis.Uncurried E ![ULift A, ρ]
abbrev G_vis (E : Type u → Type u) (A : Type (u+1)) (ρ : Type (u+1)) : Type (u+1) := G_vis.Uncurried E ![A, ρ]

-- #synth MvQPF (TypeFun.ofCurried (n := 2) @Prod.{1}) -- :)
#synth MvQPF (TypeFun.ofCurried (n := 2) @Prod) -- :)

instance inst₃ {Ans} {E : Type -> Type} : ∀i, MvQPF.{1, 1} (
instance inst₃ {Ans : Type u} {E : Type u -> Type u} : ∀i, MvQPF.{u+1, u+1} (
![
MvQPF.Const 2 (ULift (E Ans)), -- ✓ inst₁
MvQPF.Const 2 (ULift.{u+1} (E Ans)), -- ✓ inst₁
MvQPF.Comp (n := 1) (m := 2) -- ✓ inst₂
(TypeFun.ofCurried (MvQPF.Arrow (ULift.{1} Ans)))
(TypeFun.ofCurried (MvQPF.Arrow (ULift.{u+1} Ans)))
![MvQPF.Prj (n := 2) 0]
] i
) := fun
| 0 => inst₁
| 1 => inst₂

instance {Ans} {E : Type -> Type} : MvQPF.{1, 1} (G.Uncurried Ans E) := inferInstance
instance {Ans} {E : Type u -> Type u} : MvQPF.{u+1, u+1} (G.Uncurried Ans E) := inferInstance

-- #synth MvQPF.{1, 1} (G.Uncurried ?Ans ?E) -- :)
-- #synth MvQPF.{1, 1} (G_vis.Uncurried ?E) -- :)
#synth MvQPF.{_, _} (G.Uncurried ?Ans ?E) -- :)
#synth MvQPF.{_, _} (G_vis.Uncurried ?E) -- :)
end SigmaWorkaround

abbrev ConstructorFs (E : Type -> Type) : Fin2 3 -> TypeVec 2 -> Type 1 :=
abbrev ConstructorFs (E : Type u -> Type u) : Fin2 3 -> TypeVec 2 -> Type (u+1) :=
![G_ret.Uncurried E, G_tau.Uncurried E, G_vis.Uncurried E]

/- For some reason tc synthesis can't figure these out by itself
Expand All @@ -104,7 +99,7 @@ abbrev ConstructorFs (E : Type -> Type) : Fin2 3 -> TypeVec 2 -> Type 1 :=
instance [inst : MvQPF (G_vis.Uncurried E)] : MvQPF (ConstructorFs E 0) := inst
instance [inst : MvQPF (G_tau.Uncurried E)] : MvQPF (ConstructorFs E 1) := inst
instance [inst : MvQPF (G_ret.Uncurried E)] : MvQPF (ConstructorFs E 2) := inst
instance {E : Type -> Type} : (i : Fin2 3) -> MvQPF (ConstructorFs E i) :=
instance {E : Type u -> Type u} : (i : Fin2 3) -> MvQPF (ConstructorFs E i) :=
fun
| 0 => inferInstance
| 1 => inferInstance
Expand All @@ -116,14 +111,15 @@ instance {E : Type -> Type} : (i : Fin2 3) -> MvQPF (ConstructorFs E i) :=
`Base E : (A : Type) -> (ρ : Type 1) -> Type 1`.
-/

abbrev Base.Uncurried (E : Type -> Type) : TypeFun 2 :=
abbrev Base.Uncurried (E : Type u -> Type u) : TypeFun.{u+1} 2 :=
MvQPF.Comp
(TypeFun.ofCurried (n := 3) (Shape E))
![G_ret.Uncurried E, G_tau.Uncurried E, G_vis.Uncurried E]

def Base (E : Type -> Type) (A : Type) (ρ : Type 1) : Type 1 := Base.Uncurried E ![ULift A, ρ]
abbrev Base (E : Type u -> Type u) (A : Type (u+1)) (ρ : Type (u+1)) : Type (u+1) :=
Base.Uncurried E ![ULift.{u+1} A, ρ]

instance {E : Type -> Type} : MvFunctor (Base.Uncurried E) where
instance {E : Type u -> Type u} : MvFunctor.{u+1, u+1} (Base.Uncurried E) where
map f x := MvQPF.Comp.map f x

/-
Expand All @@ -136,20 +132,20 @@ instance {E : Type -> Type} : MvFunctor (Base.Uncurried E) where
which preserve MvQPF.
-/

inductive HeadT : Type 1
inductive HeadT : Type u
| ret
| tau
| vis

def ChildT : HeadT -> TypeVec.{1} 3
abbrev ChildT : HeadT -> TypeVec.{u+1} 3
| .ret => ![PFin2 1, PFin2 0, PFin2 0] -- One `A`, zero `ρ`, zero `ν`
| .tau => ![PFin2 0, PFin2 1, PFin2 0] -- Zero `A`, one `ρ`, zero `ν` (remember, `ρ` intuitively means `ITree E A`)
| .vis => ![PFin2 0, PFin2 0, PFin2 1] -- Zero `A`, zero ρ, one ν (where ν is our `(Ans : Type) × ...`)

private def P : MvPFunctor.{1} 3 := ⟨HeadT, ChildT⟩
private abbrev F (E : Type -> Type) : TypeVec.{1} 3 -> Type 1 := Shape.Uncurried E
abbrev P : MvPFunctor.{u+1} 3 := ⟨HeadT, ChildT⟩
abbrev F (E : Type u -> Type u) : TypeVec.{u+1} 3 -> Type (u+1) := Shape.Uncurried E

private def box (E : Type -> Type) : F E α → P.Obj α
def box (E : Type u -> Type u) : F E α → P.Obj α
| .ret (a : α 2) => Sigma.mk HeadT.ret fun -- `a : A`
| 2 => fun (_ : PFin2 1) => a
| 1 => PFin2.elim0
Expand All @@ -163,12 +159,12 @@ private def box (E : Type -> Type) : F E α → P.Obj α
| 1 => PFin2.elim0
| 0 => fun (_ : PFin2 1) => e

private def unbox (E : Type -> Type) : P.Obj α → F E α
def unbox (E : Type u -> Type u) : P.Obj α → F E α
| ⟨.ret, child⟩ => Shape.ret (child 2 .fz)
| ⟨.tau, child⟩ => Shape.tau (child 1 .fz)
| ⟨.vis, child⟩ => Shape.vis (child 0 .fz)

private theorem box_unbox_id (x : P.Obj α) : box E (unbox E x) = x := by
theorem box_unbox_id (x : P.Obj α) : box E (unbox E x) = x := by
rcases x with ⟨head, child⟩
cases head <;> (
rw [unbox, box]
Expand All @@ -177,7 +173,7 @@ private theorem box_unbox_id (x : P.Obj α) : box E (unbox E x) = x := by
rfl
)

private theorem unbox_box_id (x : F E α) : unbox E (box E x) = x := by cases x <;> rfl
theorem unbox_box_id (x : F E α) : unbox E (box E x) = x := by cases x <;> rfl

instance Shape.instMvQPF : MvQPF (F E) := MvQPF.ofPolynomial P (box E) (unbox E) box_unbox_id unbox_box_id (by
intro α β f x
Expand All @@ -197,69 +193,94 @@ instance Base.instMvQPF : MvQPF (Base.Uncurried E) := inferInstance
- Define our (co-)eliminator, `cases`, etc.
-/

def Uncurried (E : Type -> Type) := MvQPF.Cofix (Base.Uncurried E)
def _root_.ITree (E : Type -> Type) (A : Type) : Type 1 := Uncurried E ![ULift A]
set_option pp.universes true
-- abbrev Uncurried (E : Type u -> Type u) : TypeVec.{u + 1} 1 -> Type (u + 1) := MvQPF.Cofix (Base.Uncurried E)
abbrev Uncurried (E : Type _ -> Type _) : TypeVec 1 -> Type _ := MvQPF.Cofix (Base.Uncurried E)

abbrev _root_.PITree.{u, v} (E : Type (max u v) -> Type (max u v)) (A : Sort v) : Type (max u v + 1) :=
Uncurried E ![ULift.{max u v + 1} <| PLift.{v} A]

/-- Just like ITree, but universe-polymorphic, allowing you to write `PITree _ P` where `P : Prop`. -/
abbrev _root_.PITree'.{u, v} (E : Type (max u v) -> Type (max u v)) (A : Sort v) : Type (max (u+1) v) :=
Uncurried E ![ULift.{max (u+1) v} <| PLift.{v} A]

abbrev _root_.ITree (E : Type _ -> Type _) (A : Type _) : Type _ := PITree E A

def ret {E : Type -> Type} {A : Type} (a : A) : ITree E A := MvQPF.Cofix.mk (Shape.ret (.up a))
def tau {E : Type -> Type} {A : Type} (t : ITree E A) : ITree E A := MvQPF.Cofix.mk (Shape.tau t)
def vis {E : Type -> Type} {A : Type} {Ans : Type} (e : E Ans) (k : Ans -> ITree E A) : ITree E A :=
MvQPF.Cofix.mk (Shape.vis ⟨Ans, .up e, fun ans => k ans.down⟩)

def corec {E : Type -> Type} {A : Type} {β : Type 1} (f : β → Base E A β) (b : β) : ITree E A
:= MvQPF.Cofix.corec (n := 1) (F := Base.Uncurried E) f b

def dest {E : Type -> Type} {A : Type} : ITree E A -> Base E A (ITree E A)
:= MvQPF.Cofix.dest
abbrev ret.{u, v} {E : Type (max u v) -> Type (max u v)} {A : Sort v} (a : A) : ITree E A
:= MvQPF.Cofix.mk (Shape.ret (.up (.up a)))
abbrev tau.{u, v} {E : Type (max u v) -> Type (max u v)} {A : Sort v} (t : ITree E A) : ITree E A
:= MvQPF.Cofix.mk (Shape.tau t)
abbrev vis.{u, v} {E : Type (max u v) -> Type (max u v)} {A : Sort v} {Ans : Type (max u v)}
(e : E Ans)
(k : Ans -> ITree E A)
: ITree E A :=
MvQPF.Cofix.mk (Shape.vis ⟨Ans, .up e, fun ans => k ans.down⟩)

def corec' {E : Type (max u v) -> Type (max u v)} {A : Sort v} {β : Type (max u v + 1)}
-- (f : (β → Base E (ULift.{max u v + 1} <| PLift.{v} A) β : Type (max u v + 1)))
(f : β → Base.Uncurried E ![ULift.{max u v + 1} <| PLift.{v} A, β])
(b : β)
: ITree.{u, v} E A
:= MvQPF.Cofix.corec (n := 1) (F := Base.Uncurried E) f b


-- def corec {E : Type (max u v) -> Type (max u v)} {A : Sort v} {β : Type (max u v + 1)}
-- (f : (β → Base E (ULift.{max u v + 1} <| PLift.{v} A) β : Type (max u v + 1)))
-- -- (f : β → Base.Uncurried E ![ULift.{max u v + 1} <| PLift.{v} A, β])
-- (b : β)
-- : ITree.{u, v} E A
-- := corec.internal f b

-- def dest {E : Type (max u v) -> Type (max u v)} {A : Sort _} : ITree E A -> Base E A (ITree E A)
-- := MvQPF.Cofix.dest.{u, v}

@[cases_eliminator, elab_as_elim]
def cases {E : Type -> Type} {A : Type} {motive : ITree E A -> Sort u}
(ret : (r : A) → motive (ret r))
(tau : (x : ITree E A) → motive (tau x))
(vis : {Ans : Type} -> (e : E Ans) → (k : Ans → ITree E A) → motive (vis e k))
(x : ITree E A) : motive x :=
match h : MvQPF.Cofix.dest x with
| .ret (.up r) =>
have h : x = ITree.ret r := by
apply_fun MvQPF.Cofix.mk at h
simpa [MvQPF.Cofix.mk_dest] using h
h ▸ ret r
| .tau y =>
have h : x = ITree.tau y := by
apply_fun MvQPF.Cofix.mk at h
simpa [MvQPF.Cofix.mk_dest] using h
h ▸ tau y
| .vis ⟨Ans, .up e, k⟩ =>
have h : x = ITree.vis e (fun ans => k (.up ans)) := by
apply_fun MvQPF.Cofix.mk at h
simpa [MvQPF.Cofix.mk_dest] using h
h ▸ vis e (fun ans => k (.up ans))
def cases {E : Type (max u v) -> Type (max u v)} {A : Sort v} {motive : ITree.{u, v} E A -> Sort w}
(ret : (r : A) → motive (ITree.ret.{u, v} r))
(tau : (x : ITree.{u, v} E A) → motive (tau.{u, v} x))
(vis : {Ans : Type _} -> (e : E Ans) → (k : Ans → ITree E A) → motive (vis.{u, v} e k))
(x : ITree.{u, v} E A) : motive x
:= sorry
-- match h : MvQPF.Cofix.dest x with
-- | .ret (.up r) =>
-- have h : x = ITree.ret r := by
-- apply_fun MvQPF.Cofix.mk at h
-- simpa [MvQPF.Cofix.mk_dest] using h
-- h ▸ ret r
-- | .tau y =>
-- have h : x = ITree.tau y := by
-- apply_fun MvQPF.Cofix.mk at h
-- simpa [MvQPF.Cofix.mk_dest] using h
-- h ▸ tau y
-- | .vis ⟨Ans, .up e, k⟩ =>
-- have h : x = ITree.vis e (fun ans => k (.up ans)) := by
-- apply_fun MvQPF.Cofix.mk at h
-- simpa [MvQPF.Cofix.mk_dest] using h
-- h ▸ vis e (fun ans => k (.up ans))

#exit

-- Computation rules
theorem cases_ret : cases (motive := motive) c_ret c_tau c_vis (.ret r) = c_ret r := rfl
theorem cases_tau : cases (motive := motive) c_ret c_tau c_vis (.tau t) = c_tau t := sorry
theorem cases_vis : cases (motive := motive) c_ret c_tau c_vis (.vis e k) = c_vis e k := sorry

-- Without these being irreducible, some declarations in other files get a whnf/isDefEq timeout:
attribute [irreducible] Uncurried ITree ret tau vis corec dest cases

/-
# Some convenience stuff
# Some common stuff
-/

-- This implementation is extremely brittle. It is deceptively simple, but moving things around just
-- a little bit can break it, likely because of the order in which metavars get assigned.
def Base.map (f : X -> Y) : Base E A X -> Base E A Y :=
fun (bX : Base.Uncurried E (![ULift A] ::: X)) =>
let arrow : TypeVec.Arrow ((![ULift A] : TypeVec 1) ::: X) (![ULift A] ::: Y) := TypeVec.appendFun (n := 1)
(α := ![ULift A]) (α' := ![ULift A])
TypeVec.id
f
let bY := MvFunctor.map (n := 2) (F := Uncurried E) arrow bX
bY
def spin : ITree E A := corec (fun .unit => .tau .unit) PUnit.unit


/-- Just a convenience function. Re-plays a tree within another tree. -/
def Base.replay (ta : ITree E A₁) (fTree : ITree E A₁ -> C) (fRet : A₁ -> A₂ := by exact id) : ITree.Base E A₂ C :=
match ta.dest with
| .ret (.up a : _) => .ret (.up (fRet a))
| .tau (t : ITree E A₁) => .tau (fTree t)
| .vis ⟨Ans, e, k⟩ => .vis ⟨Ans, e, (fun x => fTree (k x))⟩
def Base.replay (ta : ITree E A₁) (fTree : ITree E A₁ -> C) (fRet : A₁ -> A₂ := by exact id) : ITree.Base E A₂ C := sorry
-- match ta.dest with
-- | .ret (.up a : _) => .ret (.up (fRet a))
-- | .tau (t : ITree E A₁) => .tau (fTree t)
-- | .vis ⟨_, e, k⟩ => .vis e (fun x => fTree (k x))

def Base.Map (f : C -> D) : TypeVec.Arrow (TypeVec.ofList [C, B, E]) (TypeVec.ofList [D, B, E])
:= TypeVec.appendFun TypeVec.id f