(*
  Higher-Order Abstract Syntax in Type Theory
  Venanzio Capretta
  Foundation Group Seminar, Nijmegen, 24 October 2006
*)

Require Import HOUA.  (* Import Higher-Order Universal Algebra package. *)
Open Scope hoas_scope.

Set Implicit Arguments.

(* Definition of simple types. *)
(* Object-level simple types: one base type [o] and the function-type
   former [arr]. *)
Inductive type : Set :=
  | o : type
  | arr : type -> type -> type.

(* [t1 --> t2] abbreviates [arr t1 t2]. It is right associative, so
   [a --> b --> c] stands for [a --> (b --> c)]. *)
Infix "-->" := arr (at level 28, right associativity).

(* Equality of types is decidable. *)
Lemma type_dec: forall t1 t2:type, {t1=t2}+{t1<>t2}.
Proof.
  (* Structural decision procedure generated by Coq. [Defined] (not [Qed])
     keeps the term transparent so that the decision can actually reduce. *)
  decide equality.
Defined.

(* Definition of operations. *)
(* Operation symbols, indexed by the types they involve: [lambda t1 t2] forms
   an abstraction of type [t1-->t2], and [app t1 t2] applies a function of
   type [t1-->t2] (their arities are fixed by [rules] below). *)
Inductive op : Set :=
  | lambda : type -> type -> op
  | app : type -> type -> op.

(* Specification of operations. *)
(* Arity of each operation symbol, written in HOUA's notation. Presumed
   reading (confirm against HOUA): the bracketed list holds the argument
   sorts, each as [ctx |-- s] where [ctx] lists the types of the variables
   the argument binds, and the sort after [//] is the result sort.
   - [lambda t1 t2]: one argument of type [t2] binding a [t1]; result [t1-->t2].
   - [app t1 t2]: closed arguments of types [t1-->t2] and [t1]; result [t2]. *)
Definition rules (f:op) : operation_type type :=
  match f with
  | lambda t1 t2 => [[t1] |-- t2] // t1-->t2
  | app t1 t2 => [[] |-- t1-->t2 , [] |-- t1] // t2
  end.

(* The signature of the simply typed lambda calculus, assembled with HOUA's
   [signature] from the sorts [type], their decidable equality [type_dec],
   the operation symbols [op] and their arities [rules]. *)
Definition lambda_sign: Signature := signature type type_dec op rules.

(* Terms over [lambda_sign]. Throughout this file [LTerm] is applied to a
   [type], i.e. [LTerm t] is used as the terms of object type [t]. *)
Definition LTerm:= Term lambda_sign.

(* [fun_body F] is the body of the abstraction represented by the meta-level
   function [F : LTerm A -> LTerm B]: the term of type [B] written (F x) in
   \x.(F x), not the abstraction itself. It is obtained with HOUA's [bind] at
   arity [[A]|--B], which feeds [F] the head of the argument list [bind]
   supplies (presumably the bound variable of type [A]; confirm in HOUA). *)
Definition fun_body (A B:type)(F:LTerm A -> LTerm B) : LTerm B :=
  bind lambda_sign ([A]|--B) (fun args => F (sl_head args)).

(* Convenience constructors and notations for lambda terms *)

(* [Var t i] and [Bnd t i] are the terms of object type [t] built from the
   index [i] by HOUA's [var] and [bnd] for this signature (presumably free and
   bound variables respectively -- confirm against HOUA). *)
Definition Var (t:type)(i:nat) : LTerm t := var lambda_sign t i.
Definition Bnd (t:type)(i:nat) : LTerm t := bnd lambda_sign t i.

(* [Lambda_fun F] is the object-level abstraction of type [t1-->t2] built from
   the meta-level function [F : LTerm t1 -> LTerm t2], via HOUA's [Opr_curry]
   at the operation symbol [lambda t1 t2]. *)
Definition Lambda_fun (t1 t2:type):
  (LTerm t1 -> LTerm t2) -> LTerm (t1-->t2) :=
Opr_curry lambda_sign (lambda t1 t2).

(* Higher-order abstract syntax for \x.b: the binder [x] is an ordinary Coq
   variable (of type [LTerm t1]) bound by a Coq [fun]. *)
Notation "'Lambda' x 'dot' b" :=
  (Lambda_fun (fun x => b))
  (at level 30, right associativity).

(* [App f a] is the object-level application of [f : LTerm (t1-->t2)] to
   [a : LTerm t1], built via HOUA's [Opr_curry] at the operation symbol
   [app t1 t2]. *)
Definition App (t1 t2:type):
  LTerm (t1-->t2) -> LTerm t1 -> LTerm t2 :=
Opr_curry lambda_sign (app t1 t2).

(* [f @ a] is [App f a]. It is left associative: [f @ x @ y] is
   [(f @ x) @ y]. *)
Infix "@" := App (at level 29, left associativity).

(* Examples of terms *)

(* Church numerals: a numeral iterates a step function [s : o-->o] over a
   starting point [z : o]. *)
Definition Nat := (o-->o)-->o-->o.

(* zero = \s z. z *)
Definition zero: LTerm Nat :=
  Lambda s dot Lambda z dot z.

(* one = \s z. s z *)
Definition one: LTerm Nat :=
  Lambda s dot Lambda z dot s @ z.

(* two = \s z. s (s z) *)
Definition two: LTerm Nat :=
  Lambda s dot Lambda z dot s @ (s @ z).

(* succ n = \s z. s (n s z) : one more application of [s]. *)
Definition succ: LTerm (Nat-->Nat) :=
  Lambda n dot Lambda s dot Lambda z dot
    s @ (n @ s @ z).

(* add n m = \s z. n s (m s z) : iterate [n] times starting from [m s z]. *)
Definition add: LTerm (Nat-->Nat-->Nat) :=
  Lambda n dot Lambda m dot Lambda s dot Lambda z dot
    n @ s @ (m @ s @ z).

(* mult n m = \s z. m (\w. n s w) z : iterate, [m] times, the [n]-fold
   iteration of [s]. *)
Definition mult: LTerm (Nat-->Nat--> Nat) :=
  Lambda n dot Lambda m dot Lambda s dot Lambda z dot
    m @ (Lambda w dot (n @ s @ w)) @ z.

(* Church booleans over the base type: a boolean chooses one of two arguments
   of type [o]. Standard encoding: [tru] selects its FIRST argument and [fls]
   its SECOND. [is_zero] and [if_then_else] below depend on this convention.
   (Previously the bodies of [tru] and [fls] were swapped. That made
   [tru @ b] the identity rather than the constant function, so [is_zero]
   returned its first argument for every numeral, and [if_then_else tru]
   picked the else-branch.) *)
Definition Bool := o-->o-->o.

(* fls = \x y. y *)
Definition fls: LTerm Bool :=
  Lambda x dot
  Lambda y dot
    y.

(* tru = \x y. x *)
Definition tru: LTerm Bool :=
  Lambda x dot
  Lambda y dot
    x.

(* is_zero n = \a b. n (tru b) a. After beta-reduction [tru b] is the constant
   function returning [b]. So zero (no iterations) yields [a], i.e. behaves as
   [tru], and any successor yields [b], i.e. behaves as [fls]. *)
Definition is_zero: LTerm (Nat-->Bool) :=
  Lambda n dot
  Lambda a dot
  Lambda b dot
    n @ (tru @ b) @ a.

(* if_then_else b n m = \f x. b (n f x) (m f x). The boolean chooses, pointwise,
   between the two numerals: [tru] gives [n] and [fls] gives [m]. *)
Definition if_then_else:LTerm (Bool-->Nat-->Nat-->Nat) :=
  Lambda b dot
  Lambda n dot
  Lambda m dot
  Lambda f dot
  Lambda x dot
    b @ (n @ f @ x) @ (m @ f @ x).









(* ... *)

(* HOUA's generic operation former [Opr], applied at the symbol [lambda A B] to
   meta-arguments [gs] of arity [[A]|--B], equals the [Lambda x in A to ...]
   abstraction whose body is the head meta-argument [am_head gs] instantiated
   at the one-element argument list [sort_cons A x (sort_nil _)].
   NOTE(review): the [Lambda x in A to ...] notation is not defined in the
   visible part of this file (only [Lambda x dot ...] is). Presumably it is
   in the elided section or in HOUA -- confirm. *)
Lemma Opr_Lambda_id:
  forall (A B:type)(gs:meta_args lambda_sign [[A]|--B]),
    Opr lambda_sign (lambda A B) gs
    = Lambda x in A to (am_head gs) (sort_cons A x (sort_nil _)).
Proof.
intros A B gs.
simpl.
(* Both sides are [Opr] at [lambda A B], so by [Opr_extensional] it suffices
   to relate their meta-arguments; the remaining goals are split below. *)
apply Opr_extensional with (sign:=lambda_sign)(f:=lambda A B)(hs1:=gs).
split; simpl.
(* Rebuild the argument list [al] from its head and (empty) tail so both
   sides apply [am_head gs] to the same list. *)
intro al; unfold am_head; simpl.
rewrite <- sl_nil_id with (al:=(sl_tail al)).
rewrite <- sl_head_tail_id with (al:=al); trivial.
trivial.
Qed.

(* HOUA's generic [Opr], applied at the symbol [app A B] to meta-arguments [gs]
   (two arguments with empty binding contexts, of sorts [arr A B] and [A]),
   equals the object-level application [f @ a]. Here [f] and [a] are the first
   and second meta-arguments, each instantiated at the empty argument list
   [sort_nil _]. *)
Lemma Opr_App_id:
  forall (A B:type)(gs:meta_args lambda_sign [[]|--(arr A B), []|--A]),
    Opr lambda_sign (app A B) gs
    = (am_head gs (sort_nil _)) @ (am_head (am_tail gs) (sort_nil _)).
Proof.
intros A B gs; simpl.
(* Both sides are [Opr] at [app A B]; compare meta-arguments pointwise. *)
apply Opr_extensional with (sign:=lambda_sign)(f:=app A B)(hs1:=gs).
simpl.
(* Each argument list [al] here is empty, so it can be replaced by [sort_nil]. *)
repeat split; intro al; rewrite sl_nil_id with (al:=al); trivial.
Qed.



(* HOAS-style induction principle for [LTerm]. To prove [P A t] for every
   term it suffices to handle:
   - [Var] terms and [Bnd] terms;
   - abstractions [Lambda x in A to F x], assuming [P] of their body
     [fun_body F];
   - applications [f @ a], assuming [P] of [f] and of [a].
   Derived from HOUA's generic [term_induction], using [Opr_Lambda_id] and
   [Opr_App_id] to turn generic [Opr] terms into [Lambda]/[@] form.
   NOTE(review): uses the [Lambda x in A to ...] notation, which is not
   defined in the visible part of this file. *)
Theorem LTerm_induction:
  forall (P: forall A:type, LTerm A -> Prop),
    (forall (A:type)(i:nat), P A (Var A i)) ->
    (forall (A:type)(i:nat), P A (Bnd A i)) ->
    (forall (A B:type)(F:LTerm A -> LTerm B),
            P B (fun_body F) -> P (arr A B) (Lambda x in A to F x)) ->
    (forall (A B:type)(f:LTerm (arr A B))(a:LTerm A),
            P (arr A B) f -> P A a -> P B (f @ a)) ->
  forall (A:type)(t:LTerm A), P A t.
Proof.
intros P Hvar Hbnd Hlambda Happ.
(* Reduce to HOUA's [term_invduction]; cases not closed by [auto] are the
   operation cases, split by symbol with [elim f]. *)
unfold LTerm in P; unfold LTerm; intros A t; apply term_induction; 
  clear t; clear A; auto.
intro f; elim f; auto.
(* abstraction *)
intros A B g Hg.
simpl; rewrite Opr_Lambda_id.
apply Hlambda with
  (F:= fun x => am_head g (sort_cons A x (sort_nil (Term lambda_sign)))).
simpl in Hg.
simpl in g.
rewrite am_head_tail_id with (hs:=g) in Hg.
simpl in Hg.
(* Show that [fun_body] of the chosen [F] is the [bind] of [am_head g]
   appearing in the hypothesis [Hg]. *)
unfold fun_body; simpl.
rewrite bind_extensional with (h2:=am_head g).
exact (proj1 Hg).
intro al; rewrite <- sl_nil_id with (al:=(sl_tail al)).
simpl in al.
rewrite sl_head_tail_id with (al:=al); trivial.
(* application *)
simpl; intros A B gs IH.
rewrite Opr_App_id.
rewrite am_head_tail_id with (hs:=gs) in IH; simpl in IH.
generalize IH; clear IH; intros [IH1 IH2].
apply Happ.
(* function part: strip the trivial [bind] with no bound variables *)
generalize (bind_no_args_id lambda_sign (arr A B) (am_head gs)); intro H.
rewrite <- H.
exact IH1.
(* argument part: same, for the second meta-argument *)
red in gs.
simpl in IH2.
rewrite am_head_tail_id with
  (hs:=am_tail (T:=[]|--arr A B) (Ts:=[[]|--A]) gs) in IH2.
simpl in IH2.
generalize IH2; clear IH2; intros [IH2 _].
generalize (bind_no_args_id lambda_sign A (am_head (am_tail gs))); intro H2.
rewrite <- H2.
exact IH2.
Qed.

(* Big-step evaluation: [eval e v] means that the term [e] evaluates to [v].
   - [eval_Lam]: an abstraction is a value and evaluates to itself.
   - [eval_App]: to evaluate [e1 @ e2], evaluate [e1] to an abstraction with
     HOAS body [e], then evaluate the meta-level instance [e e2]. The argument
     [e2] is substituted without being evaluated first (call-by-name).
   NOTE(review): the [Lambda x in t1 to ...] notation is not defined in the
   visible part of this file. *)
Inductive eval: forall t:type, LTerm t -> LTerm t -> Prop :=
  eval_Lam: forall t1 t2:type, forall e: LTerm t1 -> LTerm t2,
            eval (Lambda x in t1 to e x) (Lambda x in t1 to e x)
| eval_App: forall t1 t2:type,
            forall e1:LTerm (t1-->t2), forall e2: LTerm t1,
            forall e: LTerm t1 -> LTerm t2, forall v: LTerm t2,
            eval e1 (Lambda x in t1 to e x) ->
            eval (e e2) v ->
            eval (e1 @ e2) v.

(* Evaluation is deterministic: a term evaluates to at most one value.
   Proof by induction on the first derivation, inverting the second.
   NOTE(review): the script uses [tFun], [mk_term_unicity], [term_unicity]
   and [Fun_inj], none of which is defined in the visible source. Together
   with the author's note below, this suggests the proof may be unfinished or
   written against a different development -- confirm it actually checks. *)
Lemma eval_unicity :
  forall t:type, forall e v1:LTerm t,
  (eval e v1) -> forall v2:LTerm t, (eval e v2) -> v1=v2.
Proof.
intros t e v1; induction 1.
(* case eval_Lam *)
intros v2 H.
inversion H.
 (* Original author's note (possibly stale, since the script continues below):
    [inversion] now succeeds here, but the proof was not developed further. *)
unfold tFun.
apply mk_term_unicity.
rewrite -> H1.
reflexivity.
(* case eval_App *)
intros v2 H2.
inversion H2.
apply IHeval2.
specialize term_unicity with (1:=H3); intro H9.
specialize term_unicity with (1:=H4); intro H10.
subst.
clear H3 H4.
generalize (IHeval1 (tFun e4) H5); intro H3.
generalize (Fun_inj e e4 H0 H6 H3 e2); intro H4.
rewrite -> H4; auto.
Qed.
