In this short, code-heavy post, we extend some of the work from a previous post to reason about the cardinalities of sums and products.

Last time we saw how to reason about the cardinalities of sets using the Fin type. The main technical tool was an axiomless replacement for dependent destruction. This time we'll prove a slightly more interesting result about the cardinality of a particular inductive type.

First we pull in the definition of Fin, and, just like last time, we define cardinality in terms of a bijection to Fin.t n.

Require Import Fin.

Definition cardinality (A : Type) (n : nat) : Prop :=
  exists (to_fin    : A -> Fin.t n)
         (from_fin  : Fin.t n -> A),
    (forall x, from_fin (to_fin x) = x) /\
    (forall y, to_fin (from_fin y) = y).

The goal of this post is to prove a result about the cardinality of the type T, defined below.

Inductive T (A : Type) : Type :=
| T1 : T A
| T2 : A -> T A
| T3 : A -> A -> T A.

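There is an obvious relationship between the cardinality of A and the cardinality of T A: if A has n elements, then T A has one element built with T1, n built with T2, and n * n built with T3, for 1 + n + n * n in total.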

Theorem cardinality_T :
  forall A n,
    cardinality A n ->
    cardinality (T A) (1 + n + n * n).
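As a quick sanity check of the formula (not part of the development), the sketch below lists every value of T bool and counts them. The helper enumerate_T is hypothetical, and it assumes its input list contains each element of A exactly once. Since bool has n = 2 elements, we expect 1 + 2 + 2 * 2 = 7.

```coq
Require Import Coq.Lists.List.
Import ListNotations.

Inductive T (A : Type) : Type :=
| T1 : T A
| T2 : A -> T A
| T3 : A -> A -> T A.

(* Hypothetical helper: one T1, one T2 per element of xs, and one T3 per
   ordered pair of elements of xs. *)
Definition enumerate_T {A : Type} (xs : list A) : list (T A) :=
  T1 A :: map (T2 A) xs ++ flat_map (fun a => map (T3 A a) xs) xs.

(* bool has n = 2 elements, so this should be 1 + 2 + 2 * 2. *)
Compute length (enumerate_T [true; false]).  (* evaluates to 7 *)
```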

Our strategy will be to prove lemmas about the cardinalities of sums and products. We'll then express T as a sum of products and apply our lemmas.

We start with an easy lemma showing that bijections preserve cardinality.

Lemma card_bij :
  forall A B (f : A -> B) g n,
    (forall x, f (g x) = x) ->
    (forall x, g (f x) = x) ->
    cardinality A n ->
    cardinality B n.
Proof.
  unfold cardinality.
  firstorder.
  exists (fun z => x (g z)), (fun z => f (x0 z)).
  intuition congruence.
Qed.

We then define conversion functions between T and an encoding in terms of sum and product types...


Definition T_to_sum {A : Type} (x : T A) : (unit + A) + A * A :=
  match x with
    | T1 => inl (inl tt)
    | T2 a => inl (inr a)
    | T3 a b => inr (a,b)
  end.

Definition sum_to_T {A : Type} (x : (unit + A) + A * A) : T A :=
  match x with
    | inl l => match l with
                 | inl _ => T1 A
                 | inr a => T2 A a
               end
    | inr (a,b) => T3 A a b
  end.

...and prove the conversions are inverses.

Lemma to_from_T :
  forall A (x : T A),
    sum_to_T (T_to_sum x) = x.
Proof.
  destruct x; auto.
Qed.

Require Import JRWTactics.

Lemma from_to_T :
  forall A (x : unit + A + A * A),
    T_to_sum (sum_to_T x) = x.
Proof.
  intros.
  unfold T_to_sum, sum_to_T.
  repeat break_match; auto; try congruence.
  destruct u.
  auto.
Qed.

The only trick in the second proof is the custom tactic break_match, which finds a match expression in the goal or context and calls destruct on its scrutinee. You can get my tactic library on GitHub.
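If you would rather not pull in the tactic library, the same lemma can be proved by destructing x by hand. The sketch below re-declares T and the two conversions so that it runs on its own; the lemma name from_to_T_manual is hypothetical. A single destruct with a nested intro pattern performs, roughly, the case analysis that repeat break_match discovers automatically.

```coq
Inductive T (A : Type) : Type :=
| T1 : T A
| T2 : A -> T A
| T3 : A -> A -> T A.

Definition T_to_sum {A : Type} (x : T A) : (unit + A) + A * A :=
  match x with
    | T1 => inl (inl tt)
    | T2 a => inl (inr a)
    | T3 a b => inr (a,b)
  end.

Definition sum_to_T {A : Type} (x : (unit + A) + A * A) : T A :=
  match x with
    | inl l => match l with
                 | inl _ => T1 A
                 | inr a => T2 A a
               end
    | inr (a,b) => T3 A a b
  end.

(* Hypothetical name: from_to_T proved without JRWTactics. *)
Lemma from_to_T_manual :
  forall A (x : unit + A + A * A),
    T_to_sum (sum_to_T x) = x.
Proof.
  intros A x.
  (* One destruct covers the three shapes of x. *)
  destruct x as [[u | a] | [a b]].
  - destruct u. reflexivity.   (* inl (inl u): u must be tt *)
  - reflexivity.                (* inl (inr a) *)
  - reflexivity.                (* inr (a, b) *)
Qed.
```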

Using this bijection and the card_bij lemma from above, we can prove that the cardinality of T A is the same as the cardinality of unit + A + A * A.

Lemma card_T_sum_prod :
  forall A n,
    cardinality (unit + A + A * A) n ->
    cardinality (T A) n.
Proof.
  eauto using card_bij, to_from_T, from_to_T.
Qed.

With this lemma in our back pocket, all that's left is to analyze the cardinality of unit + A + A * A. We'll prove lemmas about the cardinality of unit, sums, and products, and then put them all together in the final theorem.

We start by showing that unit has cardinality 1.

Lemma card_unit :
  cardinality unit 1.
Proof.
  unfold cardinality.
  exists (fun _ => F1), (fun _ => tt).
  intuition.
  - destruct x. auto.
  - (* remaining subgoal, up to beta reduction: F1 = y, where y : Fin.t 1 *)

Here we want to do case analysis on y to prove that it must be F1. For this we'll need some of the dependent case analysis work from last time. fin_case implements a "destruction principle" for Fin.t. It also now comes with a new and improved proof: Look Ma, no decidable equality!

Definition fin_case n x :
  forall (P : Fin.t (S n) -> Type),
    P F1 ->
    (forall y, P (FS y)) ->
    P x :=
  match x as x0 in Fin.t n0
     return
       forall P,
         match n0 as n0' return (t n0' -> (t n0' -> Type) -> Type) with
           | 0 => fun _ _ => False
           | S m => fun x P => P F1 -> (forall x0, P (FS x0)) -> P x
         end x0 P
  with
    | F1 _ => fun _ H1 _ => H1
    | FS _ _ => fun _ _ HS => HS _
  end.

We wrap up fin_case in a tactic that makes it easy to do case analysis on a variable.

Ltac fin_dep_destruct v :=
  pattern v; apply fin_case; clear v; intros.
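Here is a standalone sketch of fin_case and fin_dep_destruct in action. It copies fin_case with two adaptations that are our own assumptions rather than the post's code: x gets the explicit type Fin.t (S n), and the constructor patterns are written @F1 _ and @FS _ y, which is the form we expect recent Coq versions to accept for implicit constructor arguments. The names is_first and fin1_is_F1 are hypothetical. Because fin_case is defined by a plain match, is_first evaluates by computation, and the tactic proves that Fin.t 1 has a single element.

```coq
Require Import Coq.Vectors.Fin.

(* Copy of fin_case; only the annotation on x and the @ patterns differ. *)
Definition fin_case n (x : Fin.t (S n)) :
  forall (P : Fin.t (S n) -> Type),
    P F1 ->
    (forall y, P (FS y)) ->
    P x :=
  match x as x0 in Fin.t n0
     return
       forall P,
         match n0 as n0' return (Fin.t n0' -> (Fin.t n0' -> Type) -> Type) with
           | 0 => fun _ _ => False
           | S m => fun x P => P F1 -> (forall x0, P (FS x0)) -> P x
         end x0 P
  with
    | @F1 _ => fun _ H1 _ => H1
    | @FS _ y => fun _ _ HS => HS y
  end.

Ltac fin_dep_destruct v :=
  pattern v; apply fin_case; clear v; intros.

(* Hypothetical: fin_case used as a program; it reduces on F1 and FS. *)
Definition is_first {n} (x : Fin.t (S n)) : bool :=
  fin_case n x (fun _ => bool) true (fun _ => false).

Compute is_first (@F1 2).      (* evaluates to true *)
Compute is_first (@FS 2 F1).   (* evaluates to false *)

(* Hypothetical: every element of Fin.t 1 is F1. *)
Lemma fin1_is_F1 : forall x : Fin.t 1, x = F1.
Proof.
  intros x.
  fin_dep_destruct x.
  - reflexivity.
  - (* the FS case leaves an element of Fin.t 0, which cannot exist *)
    match goal with
    | [ y : Fin.t 0 |- _ ] => inversion y
    end.
Qed.
```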

We can now finish the proof about the cardinality of unit using our new fin_dep_destruct tactic, as well as solve_by_inversion, which searches for a contradiction in the context (in this case, an inhabitant of Fin.t 0, a type with no elements).

  fin_dep_destruct y.
  + auto.
  + solve_by_inversion.
Qed.

We now prove that the cardinality of a sum of types is the sum of the cardinalities. (Say that 5 times fast.) We first define a conversion function from A + B to Fin.t (n + m) given conversions from A to Fin.t n and B to Fin.t m.

Definition from_sum {n m A B} (f : A -> Fin.t n) (g : B -> Fin.t m)
           (x : A + B) : Fin.t (n + m) :=
  match x with
    | inl a => L _ (f a)
    | inr b => R _ (g b)
  end.

This was easy to define because of the built-in functions L and R, which inject Fin.t n and Fin.t m into Fin.t (n + m), respectively.
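To make the layout concrete, the sketch below measures where L and R put elements inside Fin.t (2 + 3), using the standard library's Fin.to_nat. The helper val is hypothetical. L keeps an element's position, while R shifts it up by the size of the left part.

```coq
Require Import Coq.Vectors.Fin.

(* Hypothetical helper: the numeric position of an element of Fin.t k. *)
Definition val {k} (x : Fin.t k) : nat := proj1_sig (Fin.to_nat x).

(* L 3 embeds Fin.t 2 into Fin.t (2 + 3), keeping positions 0 and 1 ... *)
Compute val (Fin.L 3 (@F1 1)).       (* evaluates to 0 *)
Compute val (Fin.L 3 (@FS 1 F1)).    (* evaluates to 1 *)
(* ... while R 2 embeds Fin.t 3 into Fin.t (2 + 3), shifting by 2. *)
Compute val (Fin.R 2 (@F1 2)).       (* evaluates to 2 *)
Compute val (Fin.R 2 (@FS 2 F1)).    (* evaluates to 3 *)
```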

To go the other direction, we first define a sort of inverse of L and R, which, given a Fin.t (n + m), returns either a Fin.t n or a Fin.t m.

Fixpoint fin_of_sum_to_sum_of_fin {n m} {struct n} :
     Fin.t (n + m) -> Fin.t n + Fin.t m :=
  match n as n' return Fin.t (n' + m) -> Fin.t n' + Fin.t m with
    | 0 => fun x : t m => inr x
    | S n' => fun x : t (S (n' + m)) =>
                fin_case _ x _
                         (inl F1)
                         (fun x' : t (n' + m) =>
                            match fin_of_sum_to_sum_of_fin x' with
                              | inl a => inl (FS a)
                              | inr b => inr b
                            end)
  end.

Here we use fin_case as a programming construct, in contrast to our previous use in proofs. Because our fancy new "proof" (i.e., implementation) of fin_case is a plain match on its argument, it reduces by computation whenever that argument is built from F1 or FS, and this is what makes reasoning about fin_of_sum_to_sum_of_fin possible. The key property of fin_of_sum_to_sum_of_fin is that it correctly distinguishes elements in the range of L from those in the range of R. We prove one lemma for each side next.
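Before proving these properties in general, we can check them on a concrete instance. This standalone sketch copies fin_case, with x given the explicit type Fin.t (S n) and the constructor patterns written @F1 _ and @FS _ y (the form we expect recent Coq versions to accept), and a variant of fin_of_sum_to_sum_of_fin in which n and m are explicit arguments and the motive passed to fin_case is written out. All of these adaptations are ours. An element built with L 3 should come back on the left, and one built with R 2 on the right.

```coq
Require Import Coq.Vectors.Fin.

Definition fin_case n (x : Fin.t (S n)) :
  forall (P : Fin.t (S n) -> Type), P F1 -> (forall y, P (FS y)) -> P x :=
  match x as x0 in Fin.t n0
     return
       forall P,
         match n0 as n0' return (Fin.t n0' -> (Fin.t n0' -> Type) -> Type) with
         | 0 => fun _ _ => False
         | S m => fun x P => P F1 -> (forall x0, P (FS x0)) -> P x
         end x0 P
  with
  | @F1 _ => fun _ H1 _ => H1
  | @FS _ y => fun _ _ HS => HS y
  end.

(* Variant with explicit n and m and an explicit motive for fin_case. *)
Fixpoint fin_of_sum_to_sum_of_fin (n m : nat) {struct n} :
  Fin.t (n + m) -> Fin.t n + Fin.t m :=
  match n as n' return (Fin.t (n' + m) -> Fin.t n' + Fin.t m)%type with
  | 0 => fun x => inr x
  | S n' => fun x =>
      fin_case (n' + m) x (fun _ => (Fin.t (S n') + Fin.t m)%type)
        (inl F1)                          (* first slot of the left part *)
        (fun x' => match fin_of_sum_to_sum_of_fin n' m x' with
                   | inl a => inl (FS a)  (* shift left results up by one *)
                   | inr b => inr b       (* right results pass through *)
                   end)
  end.

Compute fin_of_sum_to_sum_of_fin 2 3 (Fin.L 3 (@FS 1 F1)).  (* evaluates to inl (FS F1) *)
Compute fin_of_sum_to_sum_of_fin 2 3 (Fin.R 2 (@F1 2)).     (* evaluates to inr F1 *)
```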

Lemma fin_of_sum_to_sum_of_fin_L :
  forall m (x : Fin.t m) n,
    fin_of_sum_to_sum_of_fin (L n x) = inl x.
Proof.
  induction m; intros; simpl in *.
  - solve_by_inversion.
  - fin_dep_destruct x.
    + auto.
    + simpl.
      find_higher_order_rewrite.
      auto.
Qed.

Lemma fin_of_sum_to_sum_of_fin_R :
  forall n m (x : Fin.t m),
    fin_of_sum_to_sum_of_fin (R n x) = inr x.
Proof.
  induction n; intros; simpl.
  - auto.
  - find_higher_order_rewrite.
    auto.
Qed.

We used another custom tactic here, find_higher_order_rewrite, which searches the context for a universally quantified equality, and tries to rewrite with it everywhere without introducing any new subgoals. This avoids dependence on automatically generated hypothesis names. In all cases it is possible to replace a call to find_higher_order_rewrite with a call to rewrite by passing in the correct name. Other than this new tactic, the proofs are easy.
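For readers following along without the library, here is a minimal sketch of the manual alternative on a toy goal; the lemma is hypothetical. Because the hypothesis has a fixed name, Hgf, plain rewrite can use it directly, which is the step find_higher_order_rewrite performs without needing the name.

```coq
(* Hypothetical toy lemma: rewriting with a named, universally quantified equality. *)
Lemma rewrite_by_name_example :
  forall (A B : Type) (f : A -> B) (g : B -> A),
    (forall x, g (f x) = x) ->
    forall y, g (f (g (f y))) = y.
Proof.
  intros A B f g Hgf y.
  rewrite Hgf.  (* removes one g (f _) layer *)
  rewrite Hgf.  (* removes the other; the goal is now y = y *)
  reflexivity.
Qed.
```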

With fin_of_sum_to_sum_of_fin defined, we can now define the inverse of from_sum.

Definition to_sum {n m A B} (f : Fin.t n -> A) (g : Fin.t m -> B)
           (x : Fin.t (n + m)) : A + B :=
  match fin_of_sum_to_sum_of_fin x with
    | inl a => inl (f a)
    | inr b => inr (g b)
  end.

And we can prove that the two are inverses. In the first direction, we show that encoding an element of A + B into an element of Fin.t (n + m) and then decoding it returns the same element. We'll need to use our lemmas that relate fin_of_sum_to_sum_of_fin to L and R.

Lemma to_from_sum :
  forall n m A B f1 f2 g1 g2 x,
    (forall x, f2 (f1 x) = x) ->
    (forall x, g2 (g1 x) = x) ->
    @to_sum n m A B f2 g2 (from_sum f1 g1 x) = x.
Proof.
  unfold to_sum.
  intros.
  destruct x; simpl.
  - rewrite fin_of_sum_to_sum_of_fin_L.
    auto using f_equal.
  - rewrite fin_of_sum_to_sum_of_fin_R.
    auto using f_equal.
Qed.

In the other direction, we follow the structure of to_sum, doing case analysis on the result of fin_of_sum_to_sum_of_fin.

Lemma from_to_sum :
  forall n m A B f1 f2 g1 g2 x,
    (forall x, f1 (f2 x) = x) ->
    (forall x, g1 (g2 x) = x) ->
    from_sum f1 g1 (@to_sum n m A B f2 g2 x) = x.
Proof.
  intros.
  unfold to_sum.
  break_match; simpl.
  - find_higher_order_rewrite.

To continue, we need a lemma that says that if fin_of_sum_to_sum_of_fin returns an element in the "left side", then its argument must have been in the range of L. This is sort of the converse of fin_of_sum_to_sum_of_fin_L.

Lemma L_fin_of_sum_to_sum_of_fin :
  forall n m (x : Fin.t (n + m)) t,
    fin_of_sum_to_sum_of_fin x = inl t ->
    L m t = x.
Proof.
  induction n; intros.
  - solve_by_inversion.
  - simpl in *.
    revert H.
    fin_dep_destruct x.
    + solve_by_inversion.
    + simpl in *.
      break_match.
      * find_inversion.
        simpl.
        auto using f_equal.
      * discriminate.
Qed.

Here's yet another custom tactic, find_inversion, which searches the context for an equality to invert on. In this case, we have a hypothesis of the form inl x = inl y. Inverting this yields the equality x = y. Again, the purpose of this tactic is solely to eliminate dependence on hypothesis names. One could just as easily call inversion directly.
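As with the rewriting tactic, the call can be replaced by naming the hypothesis. Here is a sketch on a toy lemma (its name is hypothetical) with exactly the shape described above, inl x = inl y. The trailing subst covers the case where inversion leaves x = y as a hypothesis instead of substituting it.

```coq
(* Hypothetical toy lemma: inverting inl x = inl y to obtain x = y. *)
Lemma inl_inversion_example :
  forall (A B : Type) (x y : A), @inl A B x = inl y -> x = y.
Proof.
  intros A B x y H.
  inversion H; subst; reflexivity.
Qed.
```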

We can now finish the first subcase of the second direction.

    auto using L_fin_of_sum_to_sum_of_fin.

The second case starts out the same way.

  - find_higher_order_rewrite.

To proceed, we'll need an analogous result about R.

Lemma R_fin_of_sum_to_sum_of_fin :
  forall n m (x : Fin.t (n + m)) t,
    fin_of_sum_to_sum_of_fin x = inr t ->
    R n t = x.
Proof.
  induction n; intros; simpl in *.
  - congruence.
  - revert H.
    fin_dep_destruct x.
    + discriminate.
    + simpl in *.
      break_match.
      * discriminate.
      * f_equal.
        solve_by_inversion.
Qed.

We can finally complete the proof of the second direction.

    auto using R_fin_of_sum_to_sum_of_fin.
Qed.

Now that we know to_sum and from_sum are inverses, we can prove the desired result about cardinalities of sums.

Lemma card_sum :
  forall A B m n,
    cardinality A m ->
    cardinality B n ->
    cardinality (A + B) (m + n).
Proof.
  unfold cardinality.
  firstorder.
  eexists.
  eexists.
  eauto using to_from_sum, from_to_sum.
Qed.

We now prove that the cardinality of a product of types is the product of the cardinalities. First, we define the isomorphisms at the Fin.t level.

The first direction takes two Fin.ts and encodes them in a larger Fin.t. This is analogous to the L and R operations (from the standard library) that we used above to convert sums. Conceptually, we divide Fin.t (n * m) into n blocks of m elements each: the value of x tells us which block we're in, and the value of y tells us which element of that block, so the pair (x, y) lands at position x * m + y. In the case where n = S n'' for some n'', (S n'') * m simplifies to m + n'' * m. This lets us use L to inject into the first block (the left side of the sum) and R to inject into the remaining blocks (the right side).

Fixpoint prod_of_fin_to_fin_of_prod {n m} {struct n} :
      Fin.t n -> Fin.t m -> Fin.t (n * m) :=
  match n as n' return Fin.t n' -> Fin.t m -> Fin.t (n' * m) with
    | 0 => fun x _ => case0 _ x
    | S n'' =>
      fun x y =>
        fin_case _ x _
                 (L _ y)
                 (fun x' => R _ (prod_of_fin_to_fin_of_prod x' y))
  end.

In the other direction, when n = S n', we again have a Fin.t of a sum, and we use the conversion function for sums to decide whether the argument is in the "current" block (the left part of the sum) or some later block.

Fixpoint fin_of_prod_to_prod_of_fin {n m} {struct n} :
      Fin.t (n * m) -> Fin.t n * Fin.t m :=
  match n as n' return Fin.t (n' * m) -> Fin.t n' * Fin.t m with
    | 0 => fun x => case0 _ x
    | S n' =>
      fun x =>
        match fin_of_sum_to_sum_of_fin x with
          | inl a => (F1, a)
          | inr b => let (x,y) := fin_of_prod_to_prod_of_fin b
                     in (FS x, y)
        end
  end.
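To check the block layout on a concrete case, the standalone sketch below copies both conversions and runs them with n = 2 blocks of m = 3 elements. It uses variants in which n and m are explicit arguments and the motives of fin_case and case0 are written out, and it writes fin_case's constructor patterns as @F1 _ and @FS _ y (the form we expect recent Coq versions to accept); these adaptations and the helper val are ours. The pair (FS F1, FS F1), that is block 1 and offset 1, should land at position 1 * 3 + 1 = 4, and decoding should return the same pair.

```coq
Require Import Coq.Vectors.Fin.

Definition fin_case n (x : Fin.t (S n)) :
  forall (P : Fin.t (S n) -> Type), P F1 -> (forall y, P (FS y)) -> P x :=
  match x as x0 in Fin.t n0
     return
       forall P,
         match n0 as n0' return (Fin.t n0' -> (Fin.t n0' -> Type) -> Type) with
         | 0 => fun _ _ => False
         | S m => fun x P => P F1 -> (forall x0, P (FS x0)) -> P x
         end x0 P
  with
  | @F1 _ => fun _ H1 _ => H1
  | @FS _ y => fun _ _ HS => HS y
  end.

Fixpoint fin_of_sum_to_sum_of_fin (n m : nat) {struct n} :
  Fin.t (n + m) -> Fin.t n + Fin.t m :=
  match n as n' return (Fin.t (n' + m) -> Fin.t n' + Fin.t m)%type with
  | 0 => fun x => inr x
  | S n' => fun x =>
      fin_case (n' + m) x (fun _ => (Fin.t (S n') + Fin.t m)%type)
        (inl F1)
        (fun x' => match fin_of_sum_to_sum_of_fin n' m x' with
                   | inl a => inl (FS a)
                   | inr b => inr b
                   end)
  end.

(* Encode: block x, offset y. (S n'') * m unfolds to m + n'' * m. *)
Fixpoint prod_of_fin_to_fin_of_prod (n m : nat) {struct n} :
  Fin.t n -> Fin.t m -> Fin.t (n * m) :=
  match n as n' return Fin.t n' -> Fin.t m -> Fin.t (n' * m) with
  | 0 => fun x _ => Fin.case0 (fun _ => Fin.t (0 * m)) x
  | S n'' => fun x y =>
      fin_case n'' x (fun _ => Fin.t (S n'' * m))
        (Fin.L (n'' * m) y)                                          (* first block *)
        (fun x' => Fin.R m (prod_of_fin_to_fin_of_prod n'' m x' y))  (* later blocks *)
  end.

(* Decode: split off the first block, recurse on the rest. *)
Fixpoint fin_of_prod_to_prod_of_fin (n m : nat) {struct n} :
  Fin.t (n * m) -> Fin.t n * Fin.t m :=
  match n as n' return (Fin.t (n' * m) -> Fin.t n' * Fin.t m)%type with
  | 0 => fun x => Fin.case0 (fun _ => (Fin.t 0 * Fin.t m)%type) x
  | S n' => fun x =>
      match fin_of_sum_to_sum_of_fin m (n' * m) x with
      | inl a => (F1, a)
      | inr b => let (i, j) := fin_of_prod_to_prod_of_fin n' m b in (FS i, j)
      end
  end.

(* Hypothetical helper: the numeric position of an element of Fin.t k. *)
Definition val {k} (x : Fin.t k) : nat := proj1_sig (Fin.to_nat x).

Compute val (prod_of_fin_to_fin_of_prod 2 3 (FS F1) (FS F1)).  (* evaluates to 4 *)
Compute fin_of_prod_to_prod_of_fin 2 3
          (prod_of_fin_to_fin_of_prod 2 3 (FS F1) (FS F1)).     (* the pair (FS F1, FS F1) *)
```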

Then given encode/decode functions on the underlying types, we can define encode/decode functions on pairs.

Definition from_prod {n m A B} (f : A -> Fin.t n) (g : B -> Fin.t m)
           (x : A * B) : Fin.t (n * m) :=
  let (a,b) := x in
  prod_of_fin_to_fin_of_prod (f a) (g b).

Definition to_prod {n m A B} (f : Fin.t n -> A) (g : Fin.t m -> B)
           (x : Fin.t (n * m)) : A * B :=
  let (a,b) := fin_of_prod_to_prod_of_fin x in (f a, g b).

We now need lemmas showing these functions are inverses. Because these functions are built from the sum conversions (L, R, and fin_of_sum_to_sum_of_fin), our proofs will reuse the lemmas we proved about them. We first show that the two conversions between pairs of Fin.ts and Fin.ts of products are inverses of each other.

Lemma fin_prod_fin_inverse :
  forall n m x a b,
    @fin_of_prod_to_prod_of_fin n m x = (a,b) ->
    prod_of_fin_to_fin_of_prod a b = x.
Proof.
  induction n; intros.
  - solve_by_inversion.
  - simpl in *.
    fold mult in *.
    break_match.
    + find_inversion.
      simpl.
      auto using L_fin_of_sum_to_sum_of_fin.
    + break_match.
      find_inversion.
      simpl.
      erewrite IHn; eauto.
      auto using R_fin_of_sum_to_sum_of_fin.
Qed.

Lemma prod_fin_prod_inverse :
  forall n m a b,
    @fin_of_prod_to_prod_of_fin n m (prod_of_fin_to_fin_of_prod a b) = (a,b).
Proof.
  induction n; intros.
  - solve_by_inversion.
  - fin_dep_destruct a; simpl.
    + rewrite fin_of_sum_to_sum_of_fin_L in *.
      auto.
    + rewrite fin_of_sum_to_sum_of_fin_R in *.
      rewrite IHn.
      auto.
Qed.

We now show that the encode/decode functions are inverses.

Lemma from_to_prod :
  forall n m A B f1 f2 g1 g2 x,
    (forall y, f2 (f1 y) = y) ->
    (forall y, g2 (g1 y) = y) ->
    from_prod f2 g2 (@to_prod n m A B f1 g1 x) = x.
Proof.
  intros.
  unfold from_prod, to_prod.
  repeat break_match.
  find_inversion.
  find_higher_order_rewrite.
  find_higher_order_rewrite.
  auto using fin_prod_fin_inverse.
Qed.

Lemma to_from_prod :
  forall n m A B f1 f2 g1 g2 x,
    (forall y, f2 (f1 y) = y) ->
    (forall y, g2 (g1 y) = y) ->
    to_prod f2 g2 (@from_prod n m A B f1 g1 x) = x.
Proof.
  intros.
  unfold from_prod, to_prod.
  repeat break_match.
  subst.
  rewrite prod_fin_prod_inverse in *.
  find_inversion.
  find_higher_order_rewrite.
  find_higher_order_rewrite.
  auto.
Qed.

Finally we can show that the cardinality of a product is the product of the cardinalities.

Lemma card_prod :
  forall A B m n,
    cardinality A m ->
    cardinality B n ->
    cardinality (A * B) (m * n).
Proof.
  unfold cardinality.
  firstorder.
  eexists.
  eexists.
  eauto using to_from_prod, from_to_prod.
Qed.

We now have a lemma that expresses T in terms of a sum of products, along with lemmas about the cardinalities of unit and of sums and products of arbitrary types. Putting them together gives the main theorem.

Theorem cardinality_T :
  forall A n,
    cardinality A n ->
    cardinality (T A) (1 + n + n * n).
Proof.
  auto using card_T_sum_prod, card_unit, card_sum, card_prod.
Qed.

Thanks again to osa1 from the Coq IRC channel for posing this problem. You can see his write-up of the problem on his blog.