(** * Library Coq.Numbers.Cyclic.DoubleCyclic.DoubleMul *)


Set Implicit Arguments.

Require Import ZArith.
Require Import BigNumPrelude.
Require Import DoubleType.
Require Import DoubleBase.

Local Open Scope Z_scope.

Section DoubleMul.
 (** Single-word operations from which the double-word multiplications
     below are built.  Their intended behavior is given by the [spec_*]
     hypotheses declared later in this section; the definitions that
     come before those hypotheses only use the operations themselves. *)
 Variable w : Type.
 Variable w_0 : w.
 Variable w_1 : w.
 (* Building double words [zn2z w] from one or two single words. *)
 Variable w_WW : w -> w -> zn2z w.
 Variable w_W0 : w -> zn2z w.
 Variable w_0W : w -> zn2z w.
 (* Single-word comparison and arithmetic; the [_c] variant returns a
    carry. *)
 Variable w_compare : w -> w -> comparison.
 Variable w_succ : w -> w.
 Variable w_add_c : w -> w -> carry w.
 Variable w_add : w -> w -> w.
 Variable w_sub: w -> w -> w.
 (* Single-word products: full (double-word) product, product truncated
    to one word, and full square. *)
 Variable w_mul_c : w -> w -> zn2z w.
 Variable w_mul : w -> w -> w.
 Variable w_square_c : w -> zn2z w.
 (* Double-word addition and subtraction: with carry or borrow ([_c]),
    modulo the double-word base, and modulo with an extra +1
    ([ww_add_carry]). *)
 Variable ww_add_c : zn2z w -> zn2z w -> carry (zn2z w).
 Variable ww_add : zn2z w -> zn2z w -> zn2z w.
 Variable ww_add_carry : zn2z w -> zn2z w -> zn2z w.
 Variable ww_sub_c : zn2z w -> zn2z w -> carry (zn2z w).
 Variable ww_sub : zn2z w -> zn2z w -> zn2z w.



 (** Full product of two double words [x = xh*wB + xl] and
     [y = yh*wB + yl], returned as a [zn2z (zn2z w)].  The outer
     products [hh = xh*yh] and [ll = xl*yl] are computed here.  The
     middle term [xh*yl + xl*yh] is delegated to [cross], which returns
     it as a pair [(wc, cc)]: [wc] counts multiples of the double-word
     base and [cc] is the remainder.  The exact contract is the
     hypothesis of [spec_double_mul_c].  The middle term has weight
     [wB], so [ccl] is added into the high word of [ll], while [wc] and
     [cch] are added into [hh]. *)
 Definition double_mul_c (cross:w->w->w->w->zn2z w -> zn2z w -> w*zn2z w) x y :=
  match x, y with
  | W0, _ => W0
  | _, W0 => W0
  | WW xh xl, WW yh yl =>
    let hh := w_mul_c xh yh in
    let ll := w_mul_c xl yl in
    let (wc,cc) := cross xh xl yh yl hh ll in
    match cc with
    (* [cc] is [W0]: only [wc] remains, added at weight [wB] within [hh]. *)
    | W0 => WW (ww_add hh (w_W0 wc)) ll
    | WW cch ccl =>
      (* Add [ccl] at weight [wB] to [ll].  A carry out of that addition
         is propagated into the high half by [ww_add_carry]. *)
      match ww_add_c (w_W0 ccl) ll with
      | C0 l => WW (ww_add hh (w_WW wc cch)) l
      | C1 l => WW (ww_add_carry hh (w_WW wc cch)) l
      end
    end
  end.

 (** Full product of two double words using four single-word products.
     It is [double_mul_c] with a [cross] function that adds [xh*yl] and
     [xl*yh] directly.  The carry of that sum becomes [wc] ([w_0] or
     [w_1]); the precomputed [hh] and [ll] are ignored. *)
 Definition ww_mul_c :=
  double_mul_c
    (fun xh xl yh yl hh ll=>
      match ww_add_c (w_mul_c xh yl) (w_mul_c xl yh) with
      | C0 cc => (w_0, cc)
      | C1 cc => (w_1, cc)
      end).

 (** The single word 2, computed as [w_1 + w_1] (see [spec_w_2]).
     [kara_prod] uses it as a count of double-word overflows in its
     middle term. *)
 Definition w_2 := w_add w_1 w_1.

 (** Karatsuba [cross] function for [double_mul_c].  Given
     [hh = xh*yh] and [ll = xl*yl], it computes the middle term
     [xh*yl + xl*yh] as [hh + ll - (xh - xl) * (yh - yl)] (identity
     [kara_prod_aux]), so only one extra single-word product is needed.
     Because [w_sub] is modular, each difference is taken in the
     non-negative direction selected by [w_compare].  The sign of the
     correction then decides whether it is subtracted from or added to
     [hh + ll].  The result [(wc, cc)] splits the middle term into a
     count [wc] of double-word bases plus a remainder [cc]
     ([spec_kara_prod]).
     NOTE(review): by [spec_to_Z] the middle term is below [2 * wB^2],
     which suggests the [w_2] results are unreachable; confirm before
     relying on this. *)
 Definition kara_prod xh xl yh yl hh ll :=
    match ww_add_c hh ll with
      (* No carry: hh + ll = m. *)
      C0 m =>
         match w_compare xl xh with
           (* xl = xh (likewise yl = yh below): the correction is zero. *)
           Eq => (w_0, m)
         | Lt =>
           match w_compare yl yh with
             Eq => (w_0, m)
           (* Same signs: subtract the correction.  The middle term is
              non-negative, so the modular [ww_sub] cannot wrap. *)
           | Lt => (w_0, ww_sub m (w_mul_c (w_sub xh xl) (w_sub yh yl)))
           (* Opposite signs: add the correction; its carry becomes wc. *)
           | Gt => match ww_add_c m (w_mul_c (w_sub xh xl) (w_sub yl yh)) with
                      C1 m1 => (w_1, m1) | C0 m1 => (w_0, m1)
                   end
           end
         | Gt =>
           match w_compare yl yh with
             Eq => (w_0, m)
           | Lt => match ww_add_c m (w_mul_c (w_sub xl xh) (w_sub yh yl)) with
                      C1 m1 => (w_1, m1) | C0 m1 => (w_0, m1)
                   end
           | Gt => (w_0, ww_sub m (w_mul_c (w_sub xl xh) (w_sub yl yh)))
           end
         end
    (* There is a carry in hh + ll: hh + ll = wwB + m, so wc starts at 1. *)
    | C1 m =>
         match w_compare xl xh with
           Eq => (w_1, m)
         | Lt =>
           match w_compare yl yh with
             Eq => (w_1, m)
           (* A borrow from ww_sub_c cancels the carry of hh + ll. *)
           | Lt => match ww_sub_c m (w_mul_c (w_sub xh xl) (w_sub yh yl)) with
                    C0 m1 => (w_1, m1) | C1 m1 => (w_0, m1)
                   end
           (* A carry here adds to the carry of hh + ll. *)
           | Gt => match ww_add_c m (w_mul_c (w_sub xh xl) (w_sub yl yh)) with
                      C1 m1 => (w_2, m1) | C0 m1 => (w_1, m1)
                   end
           end
         | Gt =>
           match w_compare yl yh with
             Eq => (w_1, m)
           | Lt => match ww_add_c m (w_mul_c (w_sub xl xh) (w_sub yh yl)) with
                      C1 m1 => (w_2, m1) | C0 m1 => (w_1, m1)
                   end
           | Gt => match ww_sub_c m (w_mul_c (w_sub xl xh) (w_sub yl yh)) with
                     C1 m1 => (w_0, m1) | C0 m1 => (w_1, m1)
                   end
           end
         end
    end.

 (** Full product of two double words by Karatsuba: [double_mul_c] with
     [kara_prod] as the middle-term function, so only three single-word
     products are needed instead of four ([spec_ww_karatsuba_c]). *)
 Definition ww_karatsuba_c := double_mul_c kara_prod.

 (** Product of two double words modulo the double-word base
     ([spec_ww_mul]).  Only [xl*yl] is needed in full.  Of the middle
     term only its low word [ccl] contributes, at weight [wB], and
     [xh*yh] does not contribute at all. *)
 Definition ww_mul x y :=
  match x, y with
  | W0, _ => W0
  | _, W0 => W0
  | WW xh xl, WW yh yl =>
    (* Low word of xh*yl + xl*yh, computed with truncating operations. *)
    let ccl := w_add (w_mul xh yl) (w_mul xl yh) in
    ww_add (w_W0 ccl) (w_mul_c xl yl)
  end.

 (** Full square of a double word [x = xh*wB + xl] ([spec_ww_square_c]).
     It follows [double_mul_c] with [y = x]: [hh] and [ll] come from
     [w_square_c].  The middle term [2*xh*xl] is obtained by adding the
     single product [xh*xl] to itself with carry, and that carry becomes
     [wc].  The recombination below is the same as in [double_mul_c]. *)
 Definition ww_square_c x :=
  match x with
  | W0 => W0
  | WW xh xl =>
    let hh := w_square_c xh in
    let ll := w_square_c xl in
    let xhxl := w_mul_c xh xl in
    let (wc,cc) :=
      match ww_add_c xhxl xhxl with
      | C0 cc => (w_0, cc)
      | C1 cc => (w_1, cc)
      end in
    match cc with
    | W0 => WW (ww_add hh (w_W0 wc)) ll
    | WW cch ccl =>
      (* Add [ccl] at weight [wB] to [ll].  A carry out of that addition
         is propagated into the high half by [ww_add_carry]. *)
      match ww_add_c (w_W0 ccl) ll with
      | C0 l => WW (ww_add hh (w_WW wc cch)) l
      | C1 l => WW (ww_add_carry hh (w_WW wc cch)) l
      end
    end
  end.

 (** Multiply-add on [n]-level words, parameterized by a single-word
     multiply-add [w_mul_add]. *)
 Section DoubleMulAddn1.
  (* Expected to satisfy h * wB + l = x * y + r whenever
     w_mul_add x y r returns (h, l); see [DoubleMulAddn1Proof]. *)
  Variable w_mul_add : w -> w -> w -> w * w.

  (** [double_mul_add_n1 n x y r] computes [x * y + r] for an [n]-level
      word [x] and single words [y] and [r].  It returns [(h, l)], where
      [h] is a single-word carry and [l] has the size of [x]
      ([spec_double_mul_add_n1]).  For [x = WW xh xl], the low half is
      processed first and its carry word becomes the addend of the high
      half. *)
  Fixpoint double_mul_add_n1 (n:nat) : word w n -> w -> w -> w * word w n :=
   match n return word w n -> w -> w -> w * word w n with
   | O => w_mul_add
   | S n1 =>
     let mul_add := double_mul_add_n1 n1 in
     fun x y r =>
     match x with
     (* x is zero: no carry, and r is lifted to the size of x by [extend]. *)
     | W0 => (w_0,extend w_0W n1 r)
     | WW xh xl =>
       let (rl,l) := mul_add xl y r in
       let (rh,h) := mul_add xh y rl in
       (rh, double_WW w_WW n1 h l)
     end
   end.

 End DoubleMulAddn1.

 (** Multiply-add on [m]-level words over an arbitrary base type [wn].
     It is built exactly like [double_mul_add_n1], with [w_mul_add_n1]
     as the base case.
     NOTE(review): no correctness lemma for [double_mul_add_mn1] appears
     in this file. *)
 Section DoubleMulAddmn1.
  Variable wn: Type.
  (* Lifts a single word into [wn]; used to widen the addend [r]. *)
  Variable extend_n : w -> wn.
  Variable wn_0W : wn -> zn2z wn.
  Variable wn_WW : wn -> wn -> zn2z wn.
  (* Base-case multiply-add on [wn], with a single-word multiplier,
     addend and carry.  Presumably an instance of [double_mul_add_n1];
     confirm against callers. *)
  Variable w_mul_add_n1 : wn -> w -> w -> w*wn.
  (** [double_mul_add_mn1 m x y r] returns a single-word carry and a
      result of the size of [x], for an [m]-level word [x] over [wn].
      By analogy with [double_mul_add_n1] this presumably represents
      [x * y + r]. *)
  Fixpoint double_mul_add_mn1 (m:nat) :
        word wn m -> w -> w -> w*word wn m :=
   match m return word wn m -> w -> w -> w*word wn m with
   | O => w_mul_add_n1
   | S m1 =>
     let mul_add := double_mul_add_mn1 m1 in
     fun x y r =>
     match x with
     (* x is zero: no carry; r is lifted into [wn] and then to the size
        of x. *)
     | W0 => (w_0,extend wn_0W m1 (extend_n r))
     | WW xh xl =>
       let (rl,l) := mul_add xl y r in
       let (rh,h) := mul_add xh y rl in
       (rh, double_WW wn_WW m1 h l)
     end
   end.

 End DoubleMulAddmn1.

 (** Single-word multiply-add: returns [(h, l)] with
     [h * wB + l = x * y + r] ([spec_w_mul_add]).  The full product is
     split into high and low words, and [r] is added to the low word.
     A carry out of that addition is absorbed by applying [w_succ] to
     the high word. *)
 Definition w_mul_add x y r :=
  match w_mul_c x y with
  | W0 => (w_0, r)
  | WW h l =>
    match w_add_c l r with
    | C0 lr => (h,lr)
    (* The high word does not overflow: a product plus a single word
       fits in a double word (see [mult_add_ineq]). *)
    | C1 lr => (w_succ h, lr)
    end
  end.

  (** Number of digits of a single word.  [wB] is the single-word base
      and [wwB] the double-word base. *)
  Variable w_digits : positive.
  (* Integer interpretation of a single word. *)
  Variable w_to_Z : w -> Z.

  Notation wB := (base w_digits).
  Notation wwB := (base (ww_digits w_digits)).
  (* Interpretation of single words and of single-word carries, with
     sign +1 for additions and -1 for subtractions. *)
  Notation "[| x |]" := (w_to_Z x) (at level 0, x at level 99).
  Notation "[+| c |]" :=
   (interp_carry 1 wB w_to_Z c) (at level 0, c at level 99).
  Notation "[-| c |]" :=
   (interp_carry (-1) wB w_to_Z c) (at level 0, c at level 99).

  (* The same for double words and double-word carries. *)
  Notation "[[ x ]]" := (ww_to_Z w_digits w_to_Z x)(at level 0, x at level 99).
  Notation "[+[ c ]]" :=
   (interp_carry 1 wwB (ww_to_Z w_digits w_to_Z) c)
   (at level 0, c at level 99).
  Notation "[-[ c ]]" :=
   (interp_carry (-1) wwB (ww_to_Z w_digits w_to_Z) c)
   (at level 0, c at level 99).

  (* Value of a pair of double words, the result type of the full
     products. *)
  Notation "[|| x ||]" :=
    (zn2z_to_Z wwB (ww_to_Z w_digits w_to_Z) x) (at level 0, x at level 99).

  (* Value of an n-level word. *)
  Notation "[! n | x !]" := (double_to_Z w_digits w_to_Z n x)
    (at level 0, x at level 99).

  (* Specifications of the section variables, assumed as hypotheses. *)
  Variable spec_more_than_1_digit: 1 < Zpos w_digits.
  Variable spec_w_0 : [|w_0|] = 0.
  Variable spec_w_1 : [|w_1|] = 1.

  (* Single words lie in [0, wB). *)
  Variable spec_to_Z : forall x, 0 <= [|x|] < wB.

  (* Double-word constructors and single-word arithmetic. *)
  Variable spec_w_WW : forall h l, [[w_WW h l]] = [|h|] * wB + [|l|].
  Variable spec_w_W0 : forall h, [[w_W0 h]] = [|h|] * wB.
  Variable spec_w_0W : forall l, [[w_0W l]] = [|l|].
  Variable spec_w_compare :
     forall x y, w_compare x y = Z.compare [|x|] [|y|].
  Variable spec_w_succ : forall x, [|w_succ x|] = ([|x|] + 1) mod wB.
  Variable spec_w_add_c : forall x y, [+|w_add_c x y|] = [|x|] + [|y|].
  Variable spec_w_add : forall x y, [|w_add x y|] = ([|x|] + [|y|]) mod wB.
  Variable spec_w_sub : forall x y, [|w_sub x y|] = ([|x|] - [|y|]) mod wB.

  (* Single-word products. *)
  Variable spec_w_mul_c : forall x y, [[ w_mul_c x y ]] = [|x|] * [|y|].
  Variable spec_w_mul : forall x y, [|w_mul x y|] = ([|x|] * [|y|]) mod wB.
  Variable spec_w_square_c : forall x, [[ w_square_c x]] = [|x|] * [|x|].

  (* Double-word addition and subtraction. *)
  Variable spec_ww_add_c : forall x y, [+[ww_add_c x y]] = [[x]] + [[y]].
  Variable spec_ww_add : forall x y, [[ww_add x y]] = ([[x]] + [[y]]) mod wwB.
  Variable spec_ww_add_carry :
         forall x y, [[ww_add_carry x y]] = ([[x]] + [[y]] + 1) mod wwB.
  Variable spec_ww_sub : forall x y, [[ww_sub x y]] = ([[x]] - [[y]]) mod wwB.
  Variable spec_ww_sub_c : forall x y, [-[ww_sub_c x y]] = [[x]] - [[y]].

  (* Bounds on the value of a double word. *)
  Lemma spec_ww_to_Z : forall x, 0 <= [[x]] < wwB.

  Lemma spec_ww_to_Z_wBwB : forall x, 0 <= [[x]] < wB^2.

  (* Register both bounds in the [mult] hint database, which [zarith]
     combines with the standard [zarith] database. *)
  Hint Resolve spec_ww_to_Z spec_ww_to_Z_wBwB : mult.
  Ltac zarith := auto with zarith mult.

  (* Numbers of the form a * wB^2 + (double word): if the whole is <=,
     so are the high parts (wBwB_lex), and strictly smaller high parts
     give a strictly smaller whole (wBwB_lex_inv). *)
  Lemma wBwB_lex: forall a b c d,
      a * wB^2 + [[b]] <= c * wB^2 + [[d]] ->
      a <= c.

  Lemma wBwB_lex_inv: forall a b c d,
      a < c ->
      a * wB^2 + [[b]] < c * wB^2 + [[d]].

  (* The overflow word of the middle term of a product is 0 or 1. *)
  Lemma sum_mul_carry : forall xh xl yh yl wc cc,
   [|wc|]*wB^2 + [[cc]] = [|xh|] * [|yl|] + [|xl|] * [|yh|] ->
   0 <= [|wc|] <= 1.

  (* A single-word product plus a single word fits in a double word. *)
  Theorem mult_add_ineq: forall xH yH crossH,
               0 <= [|xH|] * [|yH|] + [|crossH|] < wwB.

  Hint Resolve mult_add_ineq : mult.

  (* Correctness of the recombination step shared by [double_mul_c] and
     [ww_square_c], for any middle term wc * wB^2 + cc. *)
  Lemma spec_mul_aux : forall xh xl yh yl wc (cc:zn2z w) hh ll,
   [[hh]] = [|xh|] * [|yh|] ->
   [[ll]] = [|xl|] * [|yl|] ->
   [|wc|]*wB^2 + [[cc]] = [|xh|] * [|yl|] + [|xl|] * [|yh|] ->
    [||match cc with
      | W0 => WW (ww_add hh (w_W0 wc)) ll
      | WW cch ccl =>
          match ww_add_c (w_W0 ccl) ll with
          | C0 l => WW (ww_add hh (w_WW wc cch)) l
          | C1 l => WW (ww_add_carry hh (w_WW wc cch)) l
          end
      end||] = ([|xh|] * wB + [|xl|]) * ([|yh|] * wB + [|yl|]).

  (* [double_mul_c] computes the exact product for any [cross] function
     that meets this middle-term contract. *)
  Lemma spec_double_mul_c : forall cross:w->w->w->w->zn2z w -> zn2z w -> w*zn2z w,
     (forall xh xl yh yl hh ll,
        [[hh]] = [|xh|]*[|yh|] ->
        [[ll]] = [|xl|]*[|yl|] ->
        let (wc,cc) := cross xh xl yh yl hh ll in
        [|wc|]*wwB + [[cc]] = [|xh|]*[|yl|] + [|xl|]*[|yh|]) ->
     forall x y, [||double_mul_c cross x y||] = [[x]] * [[y]].

  Lemma spec_ww_mul_c : forall x y, [||ww_mul_c x y||] = [[x]] * [[y]].

  Lemma spec_w_2: [|w_2|] = 2.

  (* Algebraic identity behind [kara_prod]. *)
  Lemma kara_prod_aux : forall xh xl yh yl,
   xh*yh + xl*yl - (xh-xl)*(yh-yl) = xh*yl + xl*yh.

  (* [kara_prod] meets the middle-term contract of [spec_double_mul_c]. *)
  Lemma spec_kara_prod : forall xh xl yh yl hh ll,
   [[hh]] = [|xh|]*[|yh|] ->
   [[ll]] = [|xl|]*[|yl|] ->
   let (wc,cc) := kara_prod xh xl yh yl hh ll in
   [|wc|]*wwB + [[cc]] = [|xh|]*[|yl|] + [|xl|]*[|yh|].
  (* there is a carry in hh + ll *)

  (* If the middle term of a product exceeds wwB, it exceeds it by less
     than wwB. *)
  Lemma sub_carry : forall xh xl yh yl z,
    0 <= z ->
    [|xh|]*[|yl|] + [|xl|]*[|yh|] = wwB + z ->
    z < wwB.

  (* Adds the bounds of the double word x to the context. *)
  Ltac Spec_ww_to_Z x :=
   let H:= fresh "H" in
   assert (H:= spec_ww_to_Z x).

  (* Adds the lemma [Zmult_lt_b] instantiated with the bounds of the
     words x and y to the context.  NOTE(review): [Zmult_lt_b] is not
     defined in this file; presumably it comes from [BigNumPrelude]. *)
  Ltac Zmult_lt_b x y :=
   let H := fresh "H" in
   assert (H := Zmult_lt_b _ _ _ (spec_to_Z x) (spec_to_Z y)).

  (* Correctness of the Karatsuba product, the truncated product and the
     square. *)
  Lemma spec_ww_karatsuba_c : forall x y, [||ww_karatsuba_c x y||]=[[x]]*[[y]].

  Lemma spec_ww_mul : forall x y, [[ww_mul x y]] = [[x]]*[[y]] mod wwB.

  Lemma spec_ww_square_c : forall x, [||ww_square_c x||] = [[x]]*[[x]].

  (* Correctness of [double_mul_add_n1] for any single-word multiply-add
     that satisfies the contract [spec_w_mul_add] below. *)
  Section DoubleMulAddn1Proof.

   Variable w_mul_add : w -> w -> w -> w * w.
   Variable spec_w_mul_add : forall x y r,
    let (h,l):= w_mul_add x y r in
    [|h|]*wB+[|l|] = [|x|]*[|y|] + [|r|].

   Lemma spec_double_mul_add_n1 : forall n x y r,
     let (h,l) := double_mul_add_n1 w_mul_add n x y r in
     [|h|]*double_wB w_digits n + [!n|l!] = [!n|x!]*[|y|]+[|r|].

  End DoubleMulAddn1Proof.

  (* The concrete [w_mul_add] defined in this section meets that
     contract. *)
  Lemma spec_w_mul_add : forall x y r,
    let (h,l):= w_mul_add x y r in
    [|h|]*wB+[|l|] = [|x|]*[|y|] + [|r|].


End DoubleMul.