// Static sort for node colours: naturals constrained to the range 0..1.
sortdef color = {c:nat | 0 <= c; c <= 1}

// Colour constants used as static indices of rbtree: red is 1, black is 0.
#define red 1
#define black 0

// Red-black tree of int keys, indexed by three static integers:
//   1st index: colour of the root (red or black);
//   2nd index: black height (a B node is one higher than its children,
//              an R node has the same height as its children);
//   3rd index: violation count, which is cl+cr for an R node (the number
//              of red children of that red node) and 0 for E and B.
// Both children of every B and R node must have the same black height
// and a violation count of 0.
datatype rbtree(int, int, int) =
  // Empty tree: black, height 0, no violations.
  | E(black, 0, 0)
  // Black node over two children of any colour.
  | {cl,cr:color} {bh:nat} 
     B(black, bh+1, 0)
       of (rbtree(cl, bh, 0), int, rbtree(cr, bh, 0))
  // Red node; a red child is allowed by the type but is counted as a
  // violation in the third index.
  | {cl,cr:color} {bh:nat}
     R(red, bh, cl+cr)
       of (rbtree(cl, bh, 0), int, rbtree(cr, bh, 0))

// Combine a left subtree, a key and a right subtree of equal black
// height bh into a violation-free tree of black height bh+1.
//
// At most one violation may be present across the two subtrees
// (vl+vr <= 1).  When one subtree is a red node with a red child, the
// three keys and four subtrees involved are rearranged into a red node
// with two black children; otherwise a plain black node is built.
fun restore {cl,cr:color} {bh:nat} {vl,vr:nat | vl+vr <= 1} (
  left: rbtree(cl, bh, vl),
  key: int,
  right: rbtree(cr, bh, vr)
  ): [c:color] rbtree(c, bh + 1, 0) =
  case+ (left, key, right) of
    // left subtree is red with a red left child
    | (R(R(t1, k1, t2), k2, t3), k3, t4) =>
        R(B(t1, k1, t2), k2, B(t3, k3, t4))
    // left subtree is red with a red right child
    | (R(t1, k1, R(t2, k2, t3)), k3, t4) =>
        R(B(t1, k1, t2), k2, B(t3, k3, t4))
    // right subtree is red with a red left child
    | (t1, k1, R(R(t2, k2, t3), k3, t4)) =>
        R(B(t1, k1, t2), k2, B(t3, k3, t4))
    // right subtree is red with a red right child
    | (t1, k1, R(t2, k2, R(t3, k3, t4))) =>
        R(B(t1, k1, t2), k2, B(t3, k3, t4))
    // no red-red pair: `=>>` has this clause typechecked under the
    // assumption that none of the clauses above matched
    | (lt, k, rt) =>> B(lt, k, rt)

// Insert key x into the violation-free tree t and return a
// violation-free tree of some colour and some black height.
// A key that is already present leaves the set of keys unchanged.
fun insert {c:color} {bh:nat} (
  x: int,
  t: rbtree(c, bh, 0)
  ): [c:color] [bh2:nat] rbtree(c, bh2, 0) = let
  // Recursive worker (a closure over x).  It keeps the black height and
  // may return a tree with v violations where v <= c, i.e. only when the
  // input node was red.
  fun ins {c:color} {bh:nat} (
    node: rbtree(c, bh, 0)
  ):<cloptr1> [c2:color] [v:nat | v <= c] rbtree(c2, bh, v) =
    case+ node of
    // an empty spot becomes a red leaf holding x
    | E () => R(E (), x, E ())
    // below a black node any violation is repaired by restore
    | B(lt, k, rt) =>
        if x < k then restore(ins(lt), k, rt)
        else if k < x then restore(lt, k, ins(rt))
        else B(lt, k, rt)
    // below a red node the result is rebuilt as a red node
    | R(lt, k, rt) =>
        if x < k then R(ins(lt), k, rt)
        else if k < x then R(lt, k, ins(rt))
        else R(lt, k, rt)
  val grown = ins(t)
in
  // A red root is repainted black; anything else is returned unchanged.
  case+ grown of
  | R(lt, k, rt) => B(lt, k, rt)
  | _ =>> grown
end

// Print the keys of t in order (left subtree, key, right subtree),
// each key followed by a single space.  An empty tree prints nothing.
fun print_tree {c:color} {bh:nat} (
  t: rbtree(c, bh, 0)
  ): void =
  case+ t of
  | E () => ()
  | R(lt, k, rt) => let
      val () = print_tree(lt)
      val () = printf("%d ", @(k))
    in
      print_tree(rt)
    end
  | B(lt, k, rt) => let
      val () = print_tree(lt)
      val () = printf("%d ", @(k))
    in
      print_tree(rt)
    end

// Demo entry point: insert 5, 10, 3 and 7 into an empty tree, one at a
// time, then print the resulting keys in order.
implement main() = let
  val t1 = insert(5, E ())
  val t2 = insert(10, t1)
  val t3 = insert(3, t2)
  val t4 = insert(7, t3)
in
  print_tree(t4)
end