(****************************************************************
   Michel St-Martin and Amy Felty
   December 2015, Coq version V8.4
                                                                 
   A Verified Algorithm for Detecting Conflicts in XACML Access
   Control Rules, Michel St-Martin and Amy P. Felty, In Proceedings of
   the Fifth ACM SIGPLAN Conference on Certified Programs and Proofs
   (CPP), January 2016.

  ***************************************************************)

Require Export Tactics.

Require Export Arith.
Require Export List.
Require Import Coq.Program.Subset.
Open Scope Z_scope. 

(** Boolean comparisons on [Z] (via [Zle_bool] and [Zlt_bool]) and boolean
    connectives (via [andb] and [orb]).  The [b] suffix distinguishes them from
    the propositional [<=], [<], [/\] and [\/].  Precedence: the comparisons
    (level 70) bind tighter than [/\b] (level 80), which binds tighter than
    [\/b] (level 85). *)
Infix "<=b" := Zle_bool (at level 70): Z_scope.
Infix "<b" := Zlt_bool (at level 70): Z_scope.
Notation "A /\b B" := (andb A B) (at level 80, right associativity).
Notation "A \/b B" := (orb A B) (at level 85, right associativity).

(** Upper bound of the time-of-day domain: the number of seconds in a day
    (24 * 60 * 60 = 86400).  The time types [Z'] and [Z'NM] below are integers
    bounded by this constant. *)
Definition MIDNIGHT := 86400.

(** Times of day with midnight included: the subset type of integers [z] with
    [0 <= z <= MIDNIGHT].  This is the type of range endpoints in
    [timeInRange].  The coercion [inj_Z'_Z] (first projection) lets a [Z'] be
    used wherever a [Z] is expected, e.g. in the [<=b] comparisons of
    [coreMatch], [min'] and [max']. *)
Definition Z' := {z:Z | 0 <= z <= MIDNIGHT}.
Definition inj_Z'_Z : Z' -> Z := @proj1_sig _ _.
Coercion inj_Z'_Z : Z' >-> Z.

(** Times of day strictly before midnight: integers [z] with
    [0 <= z < MIDNIGHT].  This is the payload type of time request values
    ([timeReq]).  Unlike [Z'], it excludes [MIDNIGHT] itself.  The coercion
    [inj_Z'NM_Z] projects to the underlying [Z]. *)
Definition Z'NM := {z:Z | 0 <= z < MIDNIGHT}.
Definition inj_Z'NM_Z : Z'NM -> Z := @proj1_sig _ _.
Coercion inj_Z'NM_Z : Z'NM >-> Z.

(** [0] satisfies the [Z'] bounds; the proof term used to build [Z'0]. *)
Lemma Zero_rc : 0 <= 0 <= MIDNIGHT.
Proof.
(* Expose the numeral so omega can decide the linear constraint. *)
unfold MIDNIGHT; omega.
Qed.

(** [0] satisfies the [Z'NM] bounds; the proof term used to build [Z'NM0]. *)
Lemma ZeroNM_rc : 0 <= 0 < MIDNIGHT.
Proof.
(* Expose the numeral so omega can decide the linear constraint. *)
unfold MIDNIGHT; omega.
Qed.

(** [MIDNIGHT] satisfies the [Z'] bounds (upper bound inclusive); the proof
    term used to build [Z'MIDNIGHT]. *)
Lemma Midnight_rc : 0 <= MIDNIGHT <= MIDNIGHT.
Proof.
(* Expose the numeral so omega can decide the linear constraint. *)
unfold MIDNIGHT; omega.
Qed.

(* Time Abbreviations *)
(** Named boundary times:
    - [Z'0] is time [0] as a [Z'];
    - [Z'NM0] is time [0] as a [Z'NM];
    - [Z'MIDNIGHT] is [MIDNIGHT] as a [Z'].
    [Z'NM0] also serves as the dummy payload in [type] and in the
    [empty (timeReq _)] cores produced by [coreCheck]. *)
Definition Z'0:Z' := (exist _ 0 Zero_rc).
Definition Z'NM0:Z'NM := (exist _ 0 ZeroNM_rc).
Definition Z'MIDNIGHT:Z' := (exist _ MIDNIGHT Midnight_rc).

(* The types used for requests *)
(** One attribute value taken from a request:
    - a time of day ([timeReq], payload in [Z'NM]);
    - an integer ([intReq]);
    - [blank], which carries no payload (presumably an absent value).
    [coreMatch] yields [NA] for every [blank] value.  The constructors also
    encode the kind of value a condition concerns, see [type]. *)
Inductive reqValue: Set := 
| timeReq : Z'NM -> reqValue
| intReq : Z -> reqValue
| blank : reqValue.

(** Atomic conditions on a single request value.  Their meaning is given by
    [coreMatch]:
    - [any r]: satisfied by every value of the same kind as [r] (the payload
      of [r] is ignored);
    - [empty r]: satisfied by no value of the kind of [r];
    - [timeInRange m M]: times [t] with [m <= t < M] (upper bound excluded);
    - [intInRange m M]: integers [i] with [m <= i <= M] (upper bound included);
    - [intGt m]: integers [i] with [m < i];
    - [intLt M]: integers [i] with [i < M];
    - [na]: not applicable, always evaluates to [NA].
    [any blank], [empty blank] and ranges with empty bounds are invalid, see
    [invalidArgs]. *)
Inductive core : Set :=
| any : reqValue -> core
| empty : reqValue -> core
| timeInRange : Z' -> Z' -> core
| intInRange : Z -> Z -> core
| intGt : Z ->  core
| intLt : Z -> core
| na : core.

(** Three-valued results of evaluating a condition: [T] (satisfied), [F] (not
    satisfied) and [NA] (not applicable).  [coreMatch] returns [NA], for
    example, for a condition about another kind of value or an invalid
    condition. *)
Inductive trilean : Set :=
| T : trilean
| F : trilean
| NA : trilean.

(** Trilean negation: swaps [T] and [F] and leaves [NA] unchanged. *)
Definition trileanNot (t : trilean) : trilean :=
match t with
| T => F
| F => T
| NA => NA
end.

(** Trilean conjunction:
    - [T /\t x] is [x];
    - [F /\t x] is [F] unless [x] is [NA];
    - [NA] absorbs everything, so the result is [NA] as soon as either
      argument is [NA].
    This differs from Kleene logic, where [F] would win over an unknown value.
    The operation is commutative, see [trileanAndSym]. *)
Definition trileanAnd (t1 t2 : trilean) : trilean :=
match t1 with
| T => t2
| F =>
  (* F decides the result only when the other side is applicable. *)
  match t2 with
  | NA => NA
  | _ => F
  end
| NA => NA
end.

(** Infix/prefix notations for the trilean connectives; [/\t] uses the same
    level as the boolean [/\b]. *)
Notation "~t A" := (trileanNot A) (at level 75, right associativity).
Notation "A /\t B" := (trileanAnd A B) (at level 80, right associativity).

(** [invalidArgs c] is [true] when [c] is ill-formed.  [coreMatch] and
    [coreCheck] then treat [c] as not applicable.  The ill-formed cases are:
    - [any blank] and [empty blank]: no kind of value is designated;
    - [timeInRange m M] with [M <= m]: the range with excluded upper bound is
      empty;
    - [intInRange m M] with [M < m]: the range with included upper bound is
      empty;
    - [na] itself.
    Every other core is valid. *)
Definition invalidArgs (c: core) : bool :=
match c with
| any blank => true
| empty blank => true
| timeInRange m M => M <=b m
| intInRange m M => M <b m
| na => true
| _ => false
end.

(** [coreMatch reqv c] evaluates the condition [c] on the request value
    [reqv]:
    - [NA] if [c] is invalid (see [invalidArgs]), if [reqv] is [blank], or if
      [c] concerns a different kind of value than [reqv];
    - for a time [req]: [T] for [any], [F] for [empty], and for
      [timeInRange m M], [T] iff [m <= req < M] (upper bound excluded);
    - for an integer [req]: [T] for [any] and [F] for [empty].  It is [T]
      iff [m <= req <= M] for [intInRange m M] (upper bound included), iff
      [m < req] for [intGt m], and iff [req < M] for [intLt M]. *)
Definition coreMatch (reqv:reqValue) (c: core) : trilean :=
if invalidArgs c then NA else
match reqv with
|  timeReq req =>
   match c with
   | any (timeReq r) => T
   | empty (timeReq r) => F
   | timeInRange m M => if m <=b req /\b req <b M then T else F
   (* integer conditions and blank-kind cores do not apply to a time *)
   | _ => NA
   end

|  intReq req =>
   match c with
   | any (intReq r) => T
   | empty (intReq r) => F
   | intInRange m M => if m <=b req /\b req <=b M then T else F
   | intGt m => if m <b req then T else F
   | intLt M => if req <b M then T else F
   (* time conditions and blank-kind cores do not apply to an integer *)
   | _ => NA 
   end
(* a blank request value always evaluates to NA *)
| blank => NA
end.

(*Fixpoint isEmpty (c: core) : Prop :=
match c with
| empty n => True
| timeInRange m M => m = M
| _ => False
end.

Definition isEmptyb (c: core) : bool :=
match c with
| empty n => true
| timeInRange m M => Zeq_bool m M
| intInRange m M => Zeq_bool m M
| _ => false
end.

Definition isInvalid (c: core) : bool :=
match c with
| timeInRange m M => M <b m
| intInRange m M => M <b m
| _ => false
end.

Fixpoint isListEmpty (ls:list core) {struct ls}  : Prop :=
match ls with
| nil => True
| c::ls' => (isEmpty c) /\ (isListEmpty ls')
end.*)

(** Minimum and maximum of two integers, written as an explicit [if] on
    [<=b].  The correctness proofs below unfold these definitions and reason
    about the comparison. *)
Definition min (z1 z2:Z) : Z := if z1 <=b z2 then z1 else z2.
Definition max (z1 z2:Z) : Z := if z1 <=b z2 then z2 else z1.

(** Minimum and maximum of two bounded times.  The comparison goes through
    the coercion to [Z].  The result is always one of the two arguments, so it
    is a [Z'] without needing a new bound proof. *)
Definition min' (z1 z2:Z') : Z' := if z1 <=b z2 then z1 else z2.
Definition max' (z1 z2:Z') : Z' := if z1 <=b z2 then z2 else z1.

(** The kind of value a condition concerns, encoded as a representative
    [reqValue] with a dummy payload: [Z'NM0] for times, [0] for integers.
    [na] and the blank-kind [any]/[empty] map to [blank].  Only the
    constructor of the result is meaningful; kinds are compared with
    [typeDiff]. *)
Definition type (c : core) : reqValue :=
match c with
| any (timeReq r) | empty (timeReq r) => timeReq Z'NM0
| timeInRange m1 M1 => timeReq Z'NM0
| any (intReq r) | empty (intReq r) => intReq 0
| intInRange m1 M1 => intReq 0
| intGt m1 | intLt m1 => intReq 0
| any blank | empty blank => blank
| na => blank
end.

(** [typeDiff r1 r2] is [true] iff [r1] and [r2] are built with different
    constructors, i.e. are different kinds of value.  Payloads are
    ignored. *)
Definition typeDiff (r1 r2 : reqValue) : bool :=
match r1, r2 with
| timeReq _ , timeReq _ => false
| intReq _ , intReq _ => false
| blank, blank => false
| _, _ => true
end.

(** Intersection of two conditions.  [coreCheck c1 c2] is a core satisfied
    exactly by the values that satisfy both [c1] and [c2].  This is stated by
    [coreCheckCorrect]:
    [coreMatch reqv (coreCheck c1 c2) = (coreMatch reqv c1 /\t coreMatch reqv c2)].
    - The result is [na] when [c1] and [c2] concern different kinds of value
      (see [type] and [typeDiff]), or when either is invalid (see
      [invalidArgs]).
    - [any _] is neutral and [empty _] is absorbing.
    - When two ranges or bounds do not overlap, the result is an [empty] core
      with a placeholder payload ([Z'NM0] or [0]). *)
Definition coreCheck (c1 c2: core) : core  := 
if typeDiff (type c1) (type c2) then na else 
if invalidArgs c1 then na else
if invalidArgs c2 then na else
match c1 with
| any _ => c2
| empty _ => c1

| timeInRange m1 M1 =>
  match c2 with 
  | any (timeReq z2) => c1
  | empty (timeReq z2) => empty (timeReq Z'NM0)
  (* Ranges with excluded upper bounds overlap iff each one starts before
     the other one ends. *)
  | timeInRange m2 M2 => if m1 <b M2 /\b m2 <b M1 then timeInRange (max' m1 m2) (min' M1 M2) else empty (timeReq Z'NM0)
  (* Unreachable: after the typeDiff guard c2 is time-kind, and all
     time-kind cores are handled above. *)
  | _ => na
  end

| intInRange m1 M1 => 
  match c2 with 
  | any (intReq z2) =>  c1
  | empty (intReq z2) => empty (intReq 0)
  (* Ranges with included upper bounds overlap iff m1 <= M2 and m2 <= M1. *)
  | intInRange m2 M2 => if m1 <=b M2 /\b m2 <=b M1 then intInRange (max m1 m2) (min M1 M2) else empty (intReq 0)
  (* Values above m2 inside m1..M1 start at the larger of m1 and m2+1;
     there is one iff m2 < M1. *)
  | intGt m2 => if m2 <b M1 then intInRange (max m1 (m2+1)) M1 else empty (intReq 0)
  (* Values below M2 inside m1..M1 end at the smaller of M1 and M2-1;
     there is one iff m1 < M2. *)
  | intLt M2 => if m1 <b M2 then intInRange m1 (min M1 (M2-1)) else empty (intReq 0)
  (* Unreachable: c2 is int-kind after the typeDiff guard. *)
  | _ => na
  end

| intGt m1 => 
  match c2 with 
  | any (intReq z2) => c1
  | empty (intReq z2) => empty (intReq 0)
  (* Values above m1 inside m2..M2; there is one iff m1 < M2. *)
  | intInRange m2 M2 => if m1 <b M2 then intInRange (max (m1+1) m2) M2 else empty (intReq 0)
  (* Two lower bounds: the larger one wins. *)
  | intGt m2 => intGt (max m1 m2)
  (* m1 < i < M2 means m1+1 <= i <= M2-1, which is inhabited iff
     m1 < M2-1. *)
  | intLt M2 => if m1 <b (M2-1) then intInRange (m1+1) (M2-1) else empty (intReq 0)
  (* Unreachable: c2 is int-kind after the typeDiff guard. *)
  | _ => na
  end

| intLt M1 => 
  match c2 with 
  | any (intReq z2) => c1
  | empty (intReq z2) => empty (intReq 0)
  (* Values below M1 inside m2..M2; there is one iff m2 < M1. *)
  | intInRange m2 M2 => if m2 <b M1 then intInRange m2 (min (M1-1) M2) else empty (intReq 0)
  (* m2 < i < M1 is inhabited iff m2 < M1-1. *)
  | intGt m2 => if m2 <b (M1-1) then intInRange (m2+1) (M1-1) else empty (intReq 0)
  (* Two upper bounds: the smaller one wins. *)
  | intLt M2 => intLt (min M1 M2)
  (* Unreachable: c2 is int-kind after the typeDiff guard. *)
  | _ => na
end
(* Only na reaches this arm, and it was already rejected by invalidArgs. *)
| _ => na
end.

(* SOUNDNESS AND COMPLETENESS OF THE ALGORITHM *)
(* Soundness *)
(*Ltac simplFF :=
cut ((false \/b false) = false); trivial; intro temp; rewrite temp in *; clear temp.*)

(** Soundness of [Zeq_bool]: boolean equality implies equality. *)
Lemma ZeqBoolIsEq: forall (n m : Z), (Zeq_bool n m) = true -> n = m.
Proof.
intros.
generalize (Zeq_is_eq_bool n m); intro.
(* Split the iff; H1 is the direction from the boolean test to equality. *)
destruct H0.
apply (H1 H).
Qed.

(** Completeness of [Zeq_bool]: equal integers compare equal. *)
Lemma ZeqIsEqBool: forall (n m : Z), n = m -> (Zeq_bool n m) = true.
Proof.
intros.
generalize (Zeq_is_eq_bool n m); intro.
(* Split the iff; H0 is the direction from equality to the boolean test. *)
destruct H0.
apply (H0 H).
Qed.

(** Trichotomy on [Z], with the disjuncts in the order [n < m], [m < n],
    [n = m].  The symmetry lemmas below depend on this order when they
    [decompose] it: the equality case is the third branch. *)
Lemma ZTri: forall (n m : Z), (n < m) \/ (m < n)  \/ (n = m).
Proof.
intros.
(* Derive it from the standard trichotomy lemma by reordering the cases. *)
generalize (Ztrichotomy n m); intros H1.
decompose [or] H1; clear H1; [tauto | tauto | auto with *].
Qed.

(** Injectivity of the coercion [Z' >-> Z].  Two bounded times with the same
    underlying integer are equal.  Despite its name, the lemma proves only
    injectivity.  It uses [subset_eq] from Coq.Program.Subset.
    NOTE(review): as far as I know, that library relies on proof irrelevance;
    confirm if the development needs to stay axiom-free. *)
Lemma coercionBijective: forall (z1 z2: Z'), inj_Z'_Z z1 = inj_Z'_Z z2 -> z1 = z2.
Proof.
intros.
unfold inj_Z'_Z in H.
apply subset_eq.
trivial.
Qed.

(** [simplCoercion z1 z2 H]: given [H : inj_Z'_Z z1 = inj_Z'_Z z2] with
    [z1 z2 : Z'], derive [z1 = z2] with [coercionBijective] and substitute
    [z1] away.  The intermediate equation gets a fresh name, so the tactic
    does not fail when the context already holds a hypothesis called
    [temp]. *)
Ltac simplCoercion z1 z2 H :=
  let E := fresh "temp" in
  generalize (coercionBijective z1 z2 H); intro E; subst z1.

(** [max] is symmetric. *)
Lemma maxSym : forall (n m : Z), max n m = max m n.
Proof.
intros.
unfold max.
(* Split on the order of m and n.  The equal case is closed by substitution.
   In the strict cases, test (from the Tactics library) is used on both
   boolean comparisons, presumably rewriting them to their values. *)
generalize (ZTri m n); intro.
decompose [or] H; [| | subst m; trivial];
test (m <=b n); test (n <=b m); 
trivial.
Qed.

(** [max'] is symmetric. *)
Lemma max'Sym : forall (n m : Z'), max' n m = max' m n.
Proof.
intros.
unfold max'.
(* Same case split as maxSym.  Here the equal case only gives equality of
   the underlying integers, so simplCoercion lifts it to equality in Z'. *)
generalize (ZTri m n); intro.
decompose [or] H; [| |simplCoercion m n H1; trivial];
test (m <=b n); test (n <=b m); 
trivial.
Qed.

(** [min] is symmetric. *)
Lemma minSym : forall (n m : Z), min n m = min m n.
Proof.
intros.
unfold min.
(* Split on the order of m and n; the equal case is closed by substitution
   and the strict cases by evaluating both comparisons with test. *)
generalize (ZTri m n); intro.
decompose [or] H; [| | subst m; trivial];
test (m <=b n); test (n <=b m); 
trivial.
Qed.

(** [min'] is symmetric. *)
Lemma min'Sym : forall (n m : Z'), min' n m = min' m n.
Proof.
intros.
unfold min'.
(* As in max'Sym, the equal case needs simplCoercion to turn equality of the
   underlying integers into equality in Z'. *)
generalize (ZTri m n); intro.
decompose [or] H; [| | simplCoercion m n H1; trivial];
test (m <=b n); test (n <=b m); 
trivial.
Qed.

(** [swapMax m n] rewrites [max m n] into [max n m] in the goal using
    [maxSym].  [swapMax'], [swapMin] and [swapMin'] do the same for [max'],
    [min] and [min'].  Rewriting with the lemma directly introduces no
    hypothesis, so these tactics cannot clash with an existing name (such as
    [temp]) in the context. *)
Ltac swapMax m n := rewrite (maxSym m n).
Ltac swapMax' m n := rewrite (max'Sym m n).
Ltac swapMin m n := rewrite (minSym m n).
Ltac swapMin' m n := rewrite (min'Sym m n).

(*Lemma intRangeCheckSymetric: forall (c1 c2: core) (m1 m2 M1 M2 : Z), 
c1 = intInRange m1 M1  -> c2 = intInRange m2 M2  ->
coreCheck c1 c2 = coreCheck c2 c1.
Proof.
intros.
subst c1; subst c2.
unfold coreCheck.
simpl.
repeat findIf.
swapMax m1 m2; swapMin M1 M2; trivial.
Qed.

Lemma timeRangeCheckSymetric: forall (c c1 c2: core) (m1 m2 M1 M2 : Z'), 
c1 = timeInRange m1 M1  -> c2 = timeInRange m2 M2  ->
coreCheck c1 c2 = coreCheck c2 c1.
Proof.
intros.
subst c1; subst c2.
unfold coreCheck.
simpl.
repeat findIf.
swapMax' m1 m2; swapMin' M1 M2; trivial.
Qed.

Hint Resolve timeRangeCheckSymetric intRangeCheckSymetric maxSym minSym.

Lemma coreCheckSymetric: forall (c1 c2: core), 
coreCheck c1 c2 = coreCheck c2 c1.
Proof.
intros.
remember c1; induction c1;
remember c2; induction c2;
rewrite Heqc, Heqc0 in *;
try eauto;
unfold coreCheck;
simpl;
try destruct r;
try destruct r0;
simpl; trivial.
swapMax z (z1+1); trivial.
swapMin (z1-1) z0; trivial.
swapMax z0 (z+1); trivial.
swapMax z0 z; trivial.
swapMin z1 (z-1); trivial.
swapMin z0 z; trivial.
Qed.*)

(** [define s def] replaces the occurrences of [def] in the goal by a new
    variable [s].  It records an equation between [s] and [def] whose name is
    chosen by [remember].  The temporary local definition gets a fresh name,
    so the tactic also works when the context already contains a hypothesis
    called [temp]. *)
Ltac define s def :=
  let t := fresh "temp" in
  set (t := def);
  remember t as s;
  unfold t in *; clear t.

(*Definition type2 (reqv : reqValue) : reqValue :=
match reqv with
| timeReq r => timeReq Z'0
| intReq r => intReq 0
| blank => blank
end.

Lemma typeVsNa : forall (c : core) (reqv : reqValue),
type c <> type2 reqv -> coreMatch reqv *)

(** Case [c1 = empty r1] of [coreCheckCorrect]: evaluating the intersection
    equals the trilean conjunction of the two evaluations. *)
Lemma emptyCorrect : forall (c2 : core) (r1 reqv : reqValue), 
coreMatch reqv (coreCheck (empty r1) c2) = ((coreMatch reqv (empty r1)) /\t (coreMatch reqv c2)).
Proof.
intros.
unfold coreCheck, coreMatch.

(* Exhaustive case split on the request value, the payload of the empty
   core, and c2.  The remaining conditionals are left to findIf (from the
   Tactics library).
   NOTE(review): none of the preceding tactics introduces a hypothesis named
   C, so try discriminate C always fails silently and looks removable. *)
destruct reqv; destruct r1;
trivial;
destruct c2; try destruct r; simpl in *; trivial;
try discriminate C;
unfold coreMatch;
repeat findIf.
Qed.

(** Case [c1 = any r1] of [coreCheckCorrect]: [any] is neutral for the
    intersection, and this matches the trilean conjunction. *)
Lemma anyCorrect: forall  (c2 : core) (r1 reqv : reqValue) , 
coreMatch reqv (coreCheck (any r1) c2) = ((coreMatch reqv (any r1)) /\t (coreMatch reqv c2)).
Proof.
intros.
unfold coreCheck, coreMatch.
(* Exhaustive case split on the request value, the payload of the any core,
   and c2; findIf handles the remaining conditionals. *)
destruct reqv; destruct r1;
trivial;
destruct c2; try destruct r; simpl; trivial;
repeat findIf.
Qed.

(** Case [c1 = timeInRange m1 M1] of [coreCheckCorrect]. *)
Lemma rangeCorrect: forall (c2 : core) (reqv : reqValue) (m1 M1 : Z') , 
coreMatch reqv (coreCheck (timeInRange m1 M1) c2) = ((coreMatch reqv (timeInRange m1 M1)) /\t (coreMatch reqv c2)).
Proof.
intros.
unfold coreCheck, coreMatch.
(* Case split on the request value and c2; every case not involving the
   computed bounds is closed by findIf alone. *)
destruct reqv; trivial;
destruct c2; try destruct r; simpl; trivial;
try solve [repeat findIf].

(* The remaining goal is the range/range case, whose result is built from
   max' and min'.  Unfold them, turn every boolean test into a
   proposition, and conclude by linear arithmetic. *)
unfold max', min'.
repeat (findIf; b2PH); omega.
Qed.

(** Case [c1 = intInRange m1 M1] of [coreCheckCorrect]. *)
Lemma intRangeCorrect: forall (c2 : core) (reqv : reqValue) (m1 M1 : Z) , 
coreMatch reqv (coreCheck (intInRange m1 M1) c2) = ((coreMatch reqv (intInRange m1 M1)) /\t (coreMatch reqv c2)).
Proof.
intros.
unfold coreCheck, coreMatch.
unfold max, min.
(* core has no recursive constructors, so remember/induction acts as a plain
   case split on c2.  Each case is closed by converting the boolean tests
   to propositions and calling omega. *)
destruct reqv; trivial;
remember c2; induction c; try destruct r; simpl; trivial;
try solve [repeat (findIf;b2PH); omega].
Qed.

(** Trilean conjunction is commutative (checked on all nine cases). *)
Lemma trileanAndSym : forall (t1 t2 :trilean), (t1 /\t t2) = (t2 /\t t1).
Proof.
intros.
destruct t1; destruct t2; trivial.
Qed.

(** Case [c1 = intGt m] of [coreCheckCorrect]. *)
Lemma intGtCorrect: forall (c2 : core) (reqv : reqValue) (m: Z) , 
coreMatch reqv (coreCheck (intGt m) c2) = 
((coreMatch reqv (intGt m)) /\t (coreMatch reqv c2)).
Proof.
intros.
unfold coreCheck, coreMatch.
unfold max, min.
(* Case split on the request value and on c2 (remember/induction is a plain
   case split, since core is not recursive); close each case by arithmetic
   on the boolean tests. *)
destruct reqv; trivial;
remember c2; induction c; try destruct r; simpl; trivial;
try solve [repeat (findIf;b2PH); omega].
Qed.

(** Case [c1 = intLt M] of [coreCheckCorrect]. *)
Lemma intLtCorrect: forall (c2 : core) (reqv : reqValue) (M: Z) , 
coreMatch reqv (coreCheck (intLt M) c2) = 
((coreMatch reqv (intLt M)) /\t (coreMatch reqv c2)).
Proof.
intros.
unfold coreCheck, coreMatch.
unfold max, min.
(* Same strategy as intGtCorrect: case split on the request value and c2,
   then arithmetic on the boolean tests. *)
destruct reqv; trivial;
remember c2; induction c; try destruct r; simpl; trivial;
try solve [repeat (findIf;b2PH); omega].
Qed.

(** Case [c1 = na] of [coreCheckCorrect]: both sides evaluate to [NA]. *)
Lemma naCorrect: forall (c2 : core) (reqv : reqValue), 
coreMatch reqv (coreCheck na c2) = 
((coreMatch reqv na) /\t (coreMatch reqv c2)).
Proof.
intros.
(* coreCheck na c2 is na whichever guard fires first; findIf resolves the
   remaining conditional. *)
destruct reqv; trivial;
unfold coreCheck, invalidArgs;
findIf.
Qed.

(** Commutation of propositional conjunction, stated as an implication so it
    can be used as a [Hint Resolve] lemma. *)
Lemma andComm: forall (p q: Prop), p /\ q -> q /\ p.
Proof.
tauto.
Qed.

(*Lemma coreSym : forall (c1 c2 : core) (reqv : reqValue), 
coreMatch reqv (coreCheck c1 c2) = coreMatch reqv (coreCheck c2 c1).
Proof.
intros.
generalize (coreCheckSymetric c1 c2); intro.
rewrite H; trivial.
Qed.*)

(* Register the per-constructor correctness lemmas (and andComm) as hints.
   No database is named, so they go to the default one used by auto/eauto;
   coreCheckCorrect below relies on this. *)
Hint Resolve anyCorrect emptyCorrect andComm rangeCorrect intRangeCorrect intGtCorrect intLtCorrect naCorrect.

(** Soundness of the intersection: for every request value, evaluating
    [coreCheck c1 c2] gives the trilean conjunction of evaluating [c1] and
    [c2]. *)
Lemma coreCheckCorrect: forall (c1 c2 : core) (reqv : reqValue),
coreMatch reqv (coreCheck c1 c2) = ((coreMatch reqv c1) /\t (coreMatch reqv c2)).
Proof.
intros.
(* One case per constructor of c1; each is closed by the matching
   *Correct lemma registered with Hint Resolve above. *)
destruct c1; eauto.
Qed.

(** Non-emptiness test for a single condition:
    - [NA] for invalid cores, including [na];
    - [F] for [empty _];
    - [T] otherwise.
    [nonEmpty'] shows that a [T] answer always comes with a satisfying
    request value.
    NOTE(review): this is declared with [Fixpoint] although it is not
    recursive, so a [Definition] would be more idiomatic.  Check that the
    [simpl] steps in [nonEmpty'] still go through before changing it. *)
Fixpoint nonEmptyCore (c :core) : trilean :=
if invalidArgs c then NA else
match c with 
| empty _ => F
(* unreachable: na is already rejected by invalidArgs *)
| na => NA
| _ => T
end.

(** Every condition that [nonEmptyCore] reports as non-empty is satisfied by
    some request value.  The witnesses are:
    - [r] itself for [any r];
    - the lower bound for [timeInRange m M] (it lies below [MIDNIGHT] because
      [m < M <= MIDNIGHT]) and for [intInRange m M];
    - [z+1] for [intGt z] and [z-1] for [intLt z]. *)
Lemma nonEmpty' : forall c : core, nonEmptyCore c = T -> exists rv : reqValue, coreMatch rv c = T.
Proof.
intros.
unfold coreMatch.
(* core is not recursive, so induction is a plain case split; auto closes
   whatever it can. *)
induction c; simpl in *; try auto with *.
(* any r: r is a witness. *)
exists r.
destruct r;
unfold coreMatch;
simpl; trivial.
(* empty r: nonEmptyCore never yields T, so H is contradictory. *)
destruct r; discriminate H.

(* timeInRange z z0: if the range is invalid, H is NA = T, a
   contradiction. *)
caseSimpl (z0 <=b z).
discriminate H.

(* Otherwise z < z0 <= MIDNIGHT, so z fits in Z'NM and is a witness. *)
b2PH; assert (0 <= z <= MIDNIGHT).
destruct z; trivial.
assert (0 <= z0 <= MIDNIGHT).
destruct z0; trivial.
assert (0 <= z < MIDNIGHT); [omega|].
clear H0 H1.
exists (timeReq (exist _ (inj_Z'_Z z) H2)).
simpl in *.
test (z <=b z).
b2PH; assert (z < z0).
omega.
test (z <b z0).
simpl.
trivial.

(* intInRange z z0: the lower bound z is a witness once the invalid case
   is ruled out. *)
exists (intReq z).
caseSimpl (z0 <b z).
discriminate H.
test (z <=b z).
b2PH; assert (z <= z0).
omega.
test (z <=b z0).
simpl.
trivial.

(* intGt z: z+1 is a witness. *)
exists (intReq (z+1)).
assert (z < z + 1); [omega|].
test (z <b z + 1).
trivial.

(* intLt z: z-1 is a witness. *)
exists (intReq (z-1)).
assert (z -1 < z); [omega|].
test (z - 1 <b z).
trivial.
(* na: nonEmptyCore na is NA, so H is contradictory. *)
discriminate H.
Qed.

(** [nonEmpty'] specialised to an intersection.  If [coreCheck c1 c2] is
    reported non-empty, some request value satisfies it.  By
    [coreCheckCorrect] and the [trileanAnd] table, such a value then
    evaluates to [T] on both [c1] and [c2]. *)
Lemma nonEmpty : forall (c1 c2 : core), nonEmptyCore (coreCheck c1 c2) = T ->
exists rv : reqValue, coreMatch rv (coreCheck c1 c2) = T.
Proof.
intros.
apply (nonEmpty' (coreCheck c1 c2) H).
Qed.

(** For two valid cores of the same kind, [coreCheck] never returns [na].
    All three guards pass, and every reachable branch builds a proper core. *)
Lemma coreCheckNotNa : forall (c1 c2 : core), invalidArgs c1 = false ->
invalidArgs c2 = false -> typeDiff (type c1) (type c2) = false -> coreCheck c1 c2 <> na.
Proof.
intros.
unfold coreCheck.
rewrite H, H0, H1.
(* Case split on both cores.  Kind mismatches and constant results are
   closed by discriminate.  Each remaining goal has a conditional result
   whose two branches are both non-na constructors. *)
destruct c1; try discriminate;
destruct c2; try discriminate;
simpl in *; try (destruct r; discriminate).
destruct (z <b z2 /\b z1 <b z0); discriminate.
destruct (z <=b z2 /\b z1 <=b z0); discriminate.
destruct (z1 <b z0); discriminate.
destruct (z <b z1); discriminate.
destruct (z <b z1); discriminate.
destruct (z <b z0 - 1); discriminate.
destruct (z0 <b z); discriminate.
destruct (z0 <b z - 1); discriminate.
Qed.