2012-08-30
Safer handling of C memory in ATS
In previous ATS posts I've written about how ATS can make using C functions safer by detecting violations of the C API's requirements at compile time. This post is a walkthrough of a simple example which involves a C API that copies data from a buffer of memory to one allocated by the caller. I'll start with an initial attempt that has no safety beyond a C version and work through different options.
The API I'm using for the example is a base64 encoder from the stringencoders library. The C definition of the function is:
size_t modp_b64_encode(char* dest, const char* str, size_t len);
Given a string, str, this will store the base64 encoded value of len bytes of that string into dest. dest must be large enough to hold the result. The documentation for modp_b64_encode states:
- dest should be allocated by the caller to contain at least ((len+2)/3*4+1) bytes. This will contain the null-terminated b64 encoded result.
- str contains the bytes
- len contains the number of bytes in str
- returns length of the destination string plus the ending null byte. i.e. the result will be equal to strlen(dest) + 1.
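To make the documented size formula concrete, here is a small C sketch that evaluates it. The helper name b64_dest_size is hypothetical and is not part of stringencoders; it only restates the formula quoted above.

```c
#include <stddef.h>

/* Hypothetical helper (not part of stringencoders): the minimum number of
   bytes dest must hold for `len` input bytes, using the documented formula
   (len+2)/3*4+1. Every group of up to 3 input bytes becomes 4 output
   characters, and one extra byte is reserved for the null terminator. */
size_t b64_dest_size(size_t len) {
    return (len + 2) / 3 * 4 + 1;
}
```

For the 11 byte string "hello_world" used later in this post, the formula gives 16 encoded characters plus the terminator, so 17 bytes.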
A first attempt
A first attempt at wrapping this in ATS is:
extern fun modp_b64_encode (dest: !strptr1, str: string, len: size_t): size_t = "mac#modp_b64_encode"
dest is defined to be a linear non-null string. The ! means that the C function does not free the string and does not store it, so we are responsible for allocating the memory and freeing it. The following function uses this API to convert a string into a base64 encoded string:
extern fun string_to_base64 (s: string): strptr1
implement string_to_base64 (s) = let
  val s_len = string_length (s) // 1
  val d_len = size1_of_size ((s_len + 2) / 3 * 4) // 2
  val (pfgc, pf_bytes | p_bytes) = malloc_gc (d_len + 1) // 3
  val () = bytes_strbuf_trans (pf_bytes | p_bytes, d_len) // 4
  val dest = strptr_of_strbuf @(pfgc, pf_bytes | p_bytes) // 5
  val len = modp_b64_encode (dest, s, s_len) // 6
in
  dest // 7
end
A line by line description of what this function does follows, in the order of the numbered comments:
- Get the length of the input string as s_len.
- Compute the length of the destination string as d_len. This does not include the null terminator.
- Allocate d_len number of bytes plus one for the null terminator. This returns three things. pfgc is a proof variable that is used to ensure we free the memory later. pf_bytes is a proof that we have d_len+1 bytes allocated at a specific memory address. We need to provide this proof to other functions when we pass the pointer to the memory around so that the compiler can check we are using the memory correctly. p_bytes is the raw pointer to the memory. We can't do much with this without the proof pf_bytes saying what the pointer points to.
- We have a proof that says we have a raw array of bytes. What we want to say is that this memory is actually a pointer to a null terminated string. In ATS this is called a strbuf. The function bytes_strbuf_trans converts the proof pf_bytes from "byte array of length n" to "null terminated string of length n-1, with null terminator at n".
- Now that we have a proof saying our pointer is a string buffer we can convert that to a strptr1. The function strptr_of_strbuf does this conversion. It consumes the pfgc and pf_bytes and returns the strptr1.
- Here we call our FFI function.
- Return the resulting strptr1.
string_to_base64 is called with code like:
implement main () = let
  val s = string_to_base64 "hello_world"
in
  begin
    print s;
    print_newline ();
    strptr_free s
  end
end
Although this version of the FFI usage doesn't gain much safety over using it from C, there is some. I originally had the following and the code wouldn't type check:
val d_len = size1_of_size ((s_len + 2) / 3 * 4 + 1) // 2
val (pfgc, pf_bytes | p_bytes) = malloc_gc d_len // 3
val () = bytes_strbuf_trans (pf_bytes | p_bytes, d_len) // 4
In line 2 I include the extra byte for the null terminator. But line 4 uses this same length when converting the byte array proof into a proof of having a string buffer. What line 4 now says is we have a string buffer of length d_len with a null terminator at d_len+1. The proof pf_bytes states we only have a buffer of length d_len so it fails to type check. The ability of ATS to typecheck lengths of arrays, and its knowledge that string buffers require null terminators, saved the code from an off by one error here.
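The rule the typechecker applied here can be written down as a runtime predicate in C. This is only an illustrative sketch with a hypothetical name, fits_as_string; ATS checks the same relationship at compile time and generates no such test.

```c
#include <stddef.h>

/* Hypothetical predicate: a buffer of bsz bytes can be viewed as a
   null-terminated string of length len only if len < bsz, because one
   byte must be left over for the terminator. */
int fits_as_string(size_t bsz, size_t len) {
    return len < bsz;
}
```

With "hello_world" the rejected code asked for a string of length 17 in a 17 byte buffer, which fails this test, while the working version asks for length 16 in a 17 byte buffer.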
A version of this code is available in GitHub gist 3522299. This can be cloned and used to work through the following examples, trying out different approaches.
$ git clone git://gist.github.com/3522299.git example
$ cd example
$ make
$ ./base64
Typecheck the source buffer
One issue with this version of the wrapper is that we can pass an invalid length of the source string:
val len = modp_b64_encode (dest, s, 1000)
Since s is shorter than 1000 bytes this will access memory out of bounds. This can be fixed by using dependent types to declare that the given length must match the length of the source string:
extern fun modp_b64_encode {n:nat} (dest: !strptr1,
                                    str: string n,
                                    len: size_t n): size_t = "mac#modp_b64_encode"
We declare a type-level variable n of sort nat, a natural number. The str argument is a string of length n and the length passed is that same length. This is now a type error:
val s = string1_of_string "hello"
val len = modp_b64_encode (dest, s, 1000)
Unfortunately so is this:
val s = string1_of_string "hello"
val len = modp_b64_encode (dest, s, 2)
Ideally we want to be able to pass in a length of less than the total string length so we can base64 encode a subset of the source string. The following FFI declaration does this:
extern fun modp_b64_encode {n:nat} (dest: !strptr1,
                                    str: string n,
                                    len: sizeLte n): size_t = "mac#modp_b64_encode"
sizeLte is a prelude type definition that's defined as:
typedef sizeLte (n: int) = [i:int | 0 <= i; i <= n] size_t (i)
It is a size_t where the value is between 0 and n inclusive. Our string_to_base64 function is very similar to what it was before but uses string1 functions on the string. string1 is a dependently typed string which includes the length in the type:
extern fun string_to_base64 {n:nat} (s: string n): strptr1
implement string_to_base64 (s) = let
  val s_len = string1_length (s)
  val d_len = (s_len + 2) / 3 * 4
  val (pfgc, pf_bytes | p_bytes) = malloc_gc (d_len + 1)
  val () = bytes_strbuf_trans (pf_bytes | p_bytes, d_len)
  val dest = strptr_of_strbuf @(pfgc, pf_bytes | p_bytes)
  val len = modp_b64_encode (dest, s, s_len)
in
  dest
end
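For comparison, a C caller who wants the guarantee that sizeLte n gives would have to test it at runtime. The following is a minimal sketch with a hypothetical function name, len_within_source; it is not something stringencoders provides.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical runtime analogue of the sizeLte n constraint: the requested
   length must satisfy 0 <= len <= strlen(str). size_t is unsigned, so only
   the upper bound needs testing. */
int len_within_source(const char *str, size_t len) {
    return len <= strlen(str);
}
```

The ATS declaration makes the same decisions for the examples above without any generated code: a length of 2 or 5 for "hello" is accepted and 1000 is rejected.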
Typecheck the destination buffer
The definition for modp_b64_encode defines the destination as being a linear string. This has no length defined at the type level so it's possible to pass a linear string that is too small, resulting in out of bounds memory access. This version of the FFI definition changes the type to a strbuf.
As explained previously, a strbuf is a string buffer that is null terminated. It has two type level arguments:
abst@ype strbuf (bsz: int, len: int)
The first, bsz, is the length of the entire buffer. The second, len, is the length of the string. So bsz should be greater than len to account for the null terminator.
The new version of the modp_b64_encode function is quite a bit more complex so I'll go through it line by line:
extern fun modp_b64_encode
  {l:agz} // 1
  {n,bsz:nat | bsz >= (n + 2) / 3 * 4 + 1} // 2
  (pf_dest: !b0ytes bsz @ l >> strbuf (bsz, rlen - 1) @ l | // 3
   dest: ptr l, // 4
   str: string n,
   len: sizeLte n
  ): #[rlen:nat | rlen >= 1;rlen <= bsz] size_t rlen // 5
  = "mac#modp_b64_encode"
Put simply, this function definition takes a pointer to a byte array enforced to be the correct size and after calling enforces that the pointer is now a pointer to a null terminated string with the correct length. In detail:
- This line declares a dependent type variable l of sort agz. Sort agz is an address greater than zero. In other words it's a non-null pointer address. It's used to allow the definition to reason about the memory address of the destination buffer.
- Declare dependent type variables n and bsz of sort nat. Sort nat is an integer greater than or equal to zero. bsz is additionally constrained to be greater than or equal to (n+2)/3*4+1. This is used to ensure that the destination buffer is at least the correct size as defined by the documentation for modp_b64_encode. In this way we can enforce the constraint at the type level.
- All function arguments to the left of the | symbol in ATS are proof arguments. This line defines pf_dest which on input should be a proof that an array of bytes of length bsz is held at memory address l (b0ytes is a typedef for an array of uninitialized bytes). The ! states that this function does not consume the proof. The >> means that after the function is called the type of the proof changes to the type on the right hand side of the >>. In this case it's a string buffer at memory address l of total length bsz, but the actual string is of length rlen-1 (the result length as described next).
- The dest variable is now a simple pointer type pointing to memory address l. The proof pf_dest describes what this pointer actually points to.
- We define a result dependent type rlen for the return value. This is the length of the returned base64 encoded string plus one for the null terminator. We constrain this to be less than or equal to bsz since it's not possible for it to be longer than the destination buffer. We also constrain it to be greater than or equal to one since a result of an empty string will return length one. The # allows us to refer to this result dependent type in the function arguments, which we do for pf_dest to enforce the constraint that the string buffer length is that of the result length less one.
Our function to call this is a bit simpler:
extern fun string_to_base64 {n:nat} (s: string n): strptr1
implement string_to_base64 (s) = let
  val s_len = string1_length (s)
  val d_len = (s_len + 2) / 3 * 4 + 1
  val (pfgc, pf_bytes | p_bytes) = malloc_gc d_len
  val len = modp_b64_encode (pf_bytes | p_bytes, s, s_len)
in
  strptr_of_strbuf @(pfgc, pf_bytes | p_bytes)
end
We no longer need to do the manual conversion to a strbuf as the definition of modp_b64_encode changes the type of the proof for us.
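The two constraints in this declaration correspond to checks that a C programmer would have to write by hand, or more likely leave out. As a rough sketch, with hypothetical function names that exist only for this illustration:

```c
#include <stddef.h>

/* Hypothetical check for the line marked 2: the destination buffer of
   bsz bytes is big enough for n input bytes. */
int dest_big_enough(size_t bsz, size_t n) {
    return bsz >= (n + 2) / 3 * 4 + 1;
}

/* Hypothetical check for the line marked 5: the returned length counts
   the null terminator, so it is at least 1 and at most bsz. */
int result_in_range(size_t rlen, size_t bsz) {
    return rlen >= 1 && rlen <= bsz;
}
```

In the ATS version neither check exists at runtime. The first is discharged by the constraint solver at each call site, and the second is an assumption we make about the C function, stated in its type.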
Handle errors
So far the definitions of modp_b64_encode have avoided the error result. If the length returned is -1 there was an error. In our last definition of the function we changed the type of the proof to a strbuf after calling. What we actually need to do is only change the type if the function succeeded. On failure we must leave the proof the same as when the function was called. This forces the caller to check the result for success so they can unpack the proof into the correct type. Here's the new, even more complex, definition:
dataview encode_v (int, int, addr) =
  | {l:agz} {bsz:nat} {rlen:int | rlen > 0; rlen <= bsz }
    encode_v_succ (bsz, rlen, l) of strbuf (bsz, rlen - 1) @ l
  | {l:agz} {bsz:nat} {rlen:int | rlen <= 0 }
    encode_v_fail (bsz, rlen, l) of b0ytes bsz @ l
extern fun modp_b64_encode
  {l:agz}
  {n,bsz:nat | bsz >= (n + 2) / 3 * 4 + 1}
  (pf_dest: !b0ytes bsz @ l >> encode_v (bsz, rlen, l) | // 1
   dest: ptr l,
   str: string n,
   len: sizeLte n
  ): #[rlen:int | rlen <= bsz] size_t rlen
  = "mac#modp_b64_encode"
First we define a view called encode_v to encode the result that the destination buffer becomes. A dataview is like a datatype but is for proofs. It is erased after type checking is done. The view encode_v is dependently typed over the size of the buffer, bsz, the length of the result, rlen, and the memory address, l.
The first view constructor, encode_v_succ, is for when the result length is greater than zero. In that case the view contains a strbuf. This is the equivalent of the case in our previous iteration where we only handled success.
The second view constructor, encode_v_fail, is for the failure case. If the result length is less than or equal to zero then the view contains the original array of bytes.
In the line marked 1 above, we've changed the type the proof becomes to that of our view. Notice that the result length is one of the dependent types of the view. Now when calling modp_b64_encode we must check the return value so we can unpack the view to get the correct proof. Here is the new calling code:
extern fun string_to_base64 {n:nat} (s: string n): strptr0
implement string_to_base64 (s) = let
  val s_len = string1_length (s)
  val d_len = (s_len + 2) / 3 * 4 + 1
  val (pfgc, pf_bytes | p_bytes) = malloc_gc d_len
  val len = modp_b64_encode (pf_bytes | p_bytes, s, s_len)
in
  if len > 0 then let
    prval encode_v_succ pf = pf_bytes // 1
  in
    strptr_of_strbuf @(pfgc, pf | p_bytes)
  end
  else let
    prval encode_v_fail pf = pf_bytes // 2
    val () = free_gc (pfgc, pf | p_bytes)
  in
    strptr_null ()
  end
end
The code is similar to the previous iteration until after the modp_b64_encode call. Once that call is made the proof pf_bytes is now an encode_v. This means we can no longer use the pointer p_bytes as no proof explaining what exactly it is pointing to is in scope.
We need to unpack the proof by pattern matching against the view. To do this we branch on a conditional based on the value of the length returned by modp_b64_encode. If this length is greater than zero we know pf_bytes is an encode_v_succ. We pattern match on it in the line marked 1 to extract our proof that p_bytes is a strbuf and can then turn that into a strptr and return it.
If the length is not greater than zero we know that pf_bytes must be an encode_v_fail. We pattern match on this in the line marked 2, extracting the proof that it is the original array of bytes we allocated. This is freed in the following line and we return a null strptr. The type of string_to_base64 has changed to allow returning a null pointer.
We know the types of the encode_v proof based on the result length because this is encoded in the definition of encode_v. In that dataview definition we state which values of rlen are valid for each case. If our condition checks for the wrong value we get a type error on the pattern matching line. If we fail to handle the success or the failure case we get type errors for not consuming proofs correctly (e.g. for not freeing the allocated memory on failure).
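The closest C relative of the encode_v dataview is a tagged union. The sketch below uses hypothetical names (encode_view, classify_result) and takes the result length as a signed long so the failure case can be represented. Unlike the dataview it exists at runtime, and nothing stops a caller from reading the wrong member.

```c
#include <stddef.h>

/* Hypothetical C mirror of the two encode_v constructors. */
typedef enum { ENCODE_SUCC, ENCODE_FAIL } encode_tag;

typedef struct {
    encode_tag tag;
    union {
        char *str;   /* meaningful when tag == ENCODE_SUCC: a null-terminated string */
        void *bytes; /* meaningful when tag == ENCODE_FAIL: the untouched raw buffer */
    } u;
} encode_view;

/* Choose the constructor from the result length, following the same
   split as the dataview: rlen > 0 is success, rlen <= 0 is failure. */
encode_view classify_result(void *buf, long rlen) {
    encode_view v;
    if (rlen > 0) {
        v.tag = ENCODE_SUCC;
        v.u.str = (char *)buf;
    } else {
        v.tag = ENCODE_FAIL;
        v.u.bytes = buf;
    }
    return v;
}
```

The ATS version gets the same case split with no runtime representation, and the typechecker rejects code that uses the buffer without first matching on the right constructor.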
Conclusion
The result of our iterations ends up providing the following safety guarantees at compile time:
- We don't exceed the memory bounds of the source buffer
- We don't exceed the memory bounds of the destination buffer
- The destination buffer is at least the minimum size required by the function documentation
- We can't treat the destination buffer as a string if the function fails
- We can't treat the destination buffer as an array of bytes if the function succeeds
- Off by one errors due to null terminator handling are removed
- Checking to see if the function call failed is enforced
The complexity of the resulting definition is increased but someone familiar with ATS can read it and know what the function expects. The need to also read function documentation and find out things like the minimum size of the destination buffer is reduced.
When the ATS compiler compiles this code to C code the proof checking is removed. The resulting C code looks very much like handcoded C without runtime checks. Something like the following is generated:
char* string_to_base64(const char* s) {
  int s_len = strlen (s);
  int d_len = (s_len + 2) / 3 * 4 + 1;
  void* p_bytes = malloc (d_len);
  int len = modp_b64_encode (p_bytes, s, s_len);
  if (len > 0)
    return (char*)p_bytes;
  else {
    free(p_bytes);
    return 0;
  }
}
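To try the same control flow without installing stringencoders, here is a self-contained C sketch. demo_b64_encode is a hypothetical stand-in written for this example, not the library's implementation. It assumes the behaviour described in this post: the return value is strlen(dest) + 1. The stand-in itself never fails, so the error branch is only there to show the shape of the check.

```c
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical stand-in for modp_b64_encode. Encodes len bytes of str
   into dest with '=' padding and returns strlen(dest) + 1. */
size_t demo_b64_encode(char *dest, const char *str, size_t len) {
    static const char tbl[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const unsigned char *in = (const unsigned char *)str;
    size_t i = 0, o = 0;

    /* Full 3-byte groups become 4 output characters. */
    while (i + 2 < len) {
        unsigned v = ((unsigned)in[i] << 16) | ((unsigned)in[i + 1] << 8) | in[i + 2];
        dest[o++] = tbl[(v >> 18) & 63];
        dest[o++] = tbl[(v >> 12) & 63];
        dest[o++] = tbl[(v >> 6) & 63];
        dest[o++] = tbl[v & 63];
        i += 3;
    }
    /* One or two leftover bytes are padded with '='. */
    if (len - i == 1) {
        unsigned v = (unsigned)in[i] << 16;
        dest[o++] = tbl[(v >> 18) & 63];
        dest[o++] = tbl[(v >> 12) & 63];
        dest[o++] = '=';
        dest[o++] = '=';
    } else if (len - i == 2) {
        unsigned v = ((unsigned)in[i] << 16) | ((unsigned)in[i + 1] << 8);
        dest[o++] = tbl[(v >> 18) & 63];
        dest[o++] = tbl[(v >> 12) & 63];
        dest[o++] = tbl[(v >> 6) & 63];
        dest[o++] = '=';
    }
    dest[o] = '\0';
    return o + 1;
}

/* Same shape as the code ATS generates; the caller owns the result and
   must free it. Returns NULL on allocation or encoding failure. */
char *string_to_base64(const char *s) {
    size_t s_len = strlen(s);
    size_t d_len = (s_len + 2) / 3 * 4 + 1;
    char *p_bytes = malloc(d_len);
    if (p_bytes == NULL)
        return NULL;
    size_t len = demo_b64_encode(p_bytes, s, s_len);
    if (len == (size_t)-1 || len > d_len) {
        free(p_bytes);
        return NULL;
    }
    return p_bytes;
}
```

Every check that ATS performed during type checking is either absent here or has to be written and maintained by hand.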