Highlights

  • Pro

Organizations

@dove-project @spacelab-ccny

Pinned

  1. socioty (Public)

    Rust

  2. dove-project/benchmarks (Public)

    Benchmark code and results for DOVE.

    Python

  3. bscanf (Public)

    A standalone sscanf implementation with bounds checking.

    C · 16 stars · 2 forks

  4. maxzinkus/PhoneEncryptionDocumentArchive (Public archive)

    18 stars

  5. meteor demo
    1
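The notebook's `encode_meteor`/`decode_meteor` loop reduces to one step per generated token. The encoder XORs the next `precision` message bits with fresh mask bits and reads the result as an integer. That integer picks a token from the integer-scaled cumulative probabilities, and the encoder then counts how many leading bits the chosen token's range `[bottom, top)` fixes. The receiver recomputes those bits from the same range and unmasks them. Below is a minimal, model-free sketch of that step under stated assumptions. The four-token distribution is made up, and `keystream_bits` (HMAC-SHA512 in counter mode) is a hypothetical stand-in for the notebook's `DRBG` class, not a reproduction of its output. The bit-order helpers follow the notebook's LSB-first `bits2int`/`int2bits` convention.

```python
import hashlib
import hmac
from typing import List, Tuple

PRECISION = 16  # integer precision, matching the notebook's default


def bits2int(bits: List[int]) -> int:
    """LSB-first, as in the notebook: [0, 1, 1, 1] -> 0b1110 = 14."""
    return sum(bit << i for i, bit in enumerate(bits))


def int2bits(value: int, num_bits: int) -> List[int]:
    """LSB-first inverse of bits2int."""
    return [(value >> i) & 1 for i in range(num_bits)]


def msb_first(value: int) -> List[int]:
    return list(reversed(int2bits(value, PRECISION)))


def common_prefix_len(a: List[int], b: List[int]) -> int:
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def keystream_bits(key: bytes, counter: int, n: int) -> List[int]:
    """Hypothetical mask source: HMAC-SHA512 in counter mode (not the notebook's DRBG)."""
    block = hmac.new(key, counter.to_bytes(8, "big"), hashlib.sha512).digest()
    return [(block[i // 8] >> (7 - i % 8)) & 1 for i in range(n)]


def interval(cum_probs: List[int], selection: int) -> Tuple[int, int]:
    """Bottom-inclusive, top-exclusive integer range of the selected token."""
    bottom = cum_probs[selection - 1] if selection > 0 else 0
    return bottom, cum_probs[selection]


def encode_step(cum_probs: List[int], message_bits: List[int], mask: List[int]) -> Tuple[int, int]:
    """Pick a token from masked bits (MSB first); return (token index, bits encoded)."""
    masked = [m ^ k for m, k in zip(message_bits, mask)]
    value = bits2int(list(reversed(masked)))
    selection = next(i for i, c in enumerate(cum_probs) if c > value)
    bottom, top = interval(cum_probs, selection)
    return selection, common_prefix_len(msb_first(bottom), msb_first(top - 1))


def decode_step(cum_probs: List[int], selection: int, mask: List[int]) -> List[int]:
    """Recover the message bits fixed by the observed token and unmask them."""
    bottom, top = interval(cum_probs, selection)
    bottom_bits = msb_first(bottom)
    n = common_prefix_len(bottom_bits, msb_first(top - 1))
    return [b ^ k for b, k in zip(bottom_bits[:n], mask)]


# Toy 4-token distribution scaled to 2**16: p = 1/2, 1/4, 1/8, 1/8 (made up).
cum_probs = [32768, 49152, 57344, 65536]
key = b"\x01" * 64  # demo key only, never reuse for real traffic
message = [1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1]

tokens, recovered, i, step = [], [], 0, 0
while i < len(message):
    chunk = message[i:i + PRECISION]
    chunk += [0] * (PRECISION - len(chunk))  # pad the tail with zeros
    mask = keystream_bits(key, step, PRECISION)  # fresh mask for every token
    selection, n = encode_step(cum_probs, chunk, mask)
    tokens.append(selection)
    recovered += decode_step(cum_probs, selection, mask)
    i += n
    step += 1

print(tokens, recovered[:len(message)] == message)
```

Decoding works because any value in `[bottom, top - 1]` shares every leading bit that `bottom` and `top - 1` share, so the receiver can read those bits off `bottom` alone. Like `encode_meteor`, the sketch starts each token from the full `[0, 2**precision)` range instead of carrying a narrowed interval forward. One deliberate difference: `common_prefix_len` returns the full length when the two bit strings are identical, whereas the notebook's `num_same_from_beg` returns the last index in that case.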
    {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Meteor Live Demo","private_outputs":true,"provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyMHQxDUetlMCXtxtCQThhg9"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"Uje2sTLOnOwr"},"source":["# Meteor\n","\n","Cryptographically secure steganography for realistic distributions.\n","\n","A project by <a href=\"https://kaptchuk.com\">Gabe Kaptchuk</a>, <a href=\"http://tjo.is\">Tushar Jois</a>, <a href=\"https://isi.jhu.edu/~mgreen/\">Matthew Green</a>, and <a href=\"https://avirubin.com\">Avi Rubin</a>. Check out our [academic paper](https://eprint.iacr.org/2021/686) for more details.\n","\n","**Note**: This is a research implementation, and has not undergone the level of professional auditing required to be used as a production system. You may encounter some bugs. Reach out to [Tushar](https://twitter.com/tusharjois) if you're having trouble.\n","\n","## Getting started\n","\n","In the menu bar, go to **Runtime > Run all**. The output will appear [here](#scrollTo=MmzzrhXa57Vn). This may take a while on the initial run, since Colab has to download all of the necessary files for Meteor to run. However, future runs should be much faster -- mess with the parameters in this notebook (such as the [context](#scrollTo=b0w7H22MiiDD) and the [message](#scrollTo=tftXoMOBsSQ3)) and see what happens. Have fun!"]},{"cell_type":"code","metadata":{"id":"gKI8nADha43j"},"source":["#@title Colab setup { run: \"auto\", display-mode: \"form\" }\n","#@markdown This downloads some prereqs. It might take a while! 
You only have to run this cell once.\n","!pip install torch==1.0.1 pytorch-transformers==1.1.0 bitarray==1.0.1\n","import hashlib\n","import hmac\n","import numpy as np\n","\n","class DRBG(object):\n","    def __init__(self, key, seed):\n","        self.key = key\n","        self.val = b'\\x01' * 64\n","        self.reseed(seed)\n","\n","        self.byte_index = 0\n","        self.bit_index = 0\n","\n","    def hmac(self, key, val):\n","        return hmac.new(key, val, hashlib.sha512).digest()\n","\n","    def reseed(self, data=b''):\n","        self.key = self.hmac(self.key, self.val + b'\\x00' + data)\n","        self.val = self.hmac(self.key, self.val)\n","\n","        if data:\n","            self.key = self.hmac(self.key, self.val + b'\\x01' + data)\n","            self.val = self.hmac(self.key, self.val)\n","\n","    def generate_bits(self, n):\n","        xs = np.zeros(n, dtype=bool)\n","        for i in range(0,n):\n","            xs[i] = (self.val[self.byte_index] >> (7 - self.bit_index)) & 1\n","\n","            self.bit_index += 1\n","            if self.bit_index >= 8:\n","                self.bit_index = 0\n","                self.byte_index += 1\n","\n","            if self.byte_index >= 8:\n","                self.byte_index = 0\n","                self.val = self.hmac(self.key, self.val)\n","\n","        self.reseed()\n","        return xs\n","\n","#@title\n","\n","import torch\n","import numpy as np\n","import bitarray\n","\n","from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer\n","\n","def decode(self, token_ids, **kwargs):\n","    filtered_tokens = self.convert_ids_to_tokens(token_ids)\n","    text = self.convert_tokens_to_string(filtered_tokens)\n","    return text\n","GPT2Tokenizer.decode = decode\n","\n","def _convert_token_to_id(self, token):\n","    return self.encoder.get(token, 0)\n","GPT2Tokenizer._convert_token_to_id = _convert_token_to_id\n","\n","\n","def limit_past(past):\n","    past = list(past)\n","    for i in 
range(len(past)):\n","        past[i] = past[i][:, :, :, -1022:]\n","    return past\n","\n","def kl(q, logq, logp):\n","    res = q*(logq-logp)/0.69315\n","    res[q==0] = 0\n","    return res.sum().item() # in bits\n","\n","def entropy(q, logq):\n","    res = q*logq/0.69315\n","    res[q==0] = 0\n","    return -res.sum().item() # in bits\n","\n","# e.g. [0, 1, 1, 1] looks like 1110=14\n","def bits2int(bits):\n","    res = 0\n","    for i, bit in enumerate(bits):\n","        res += bit*(2**i)\n","    return res\n","\n","def int2bits(inp, num_bits):\n","    if num_bits == 0:\n","        return []\n","    strlist = ('{0:0%db}'%num_bits).format(inp)\n","    return [int(strval) for strval in reversed(strlist)]\n","\n","def is_sent_finish(token_idx, enc):\n","    token = enc.decoder[token_idx]\n","    return '.' in token or '!' in token or '?' in token\n","\n","def num_same_from_beg(bits1, bits2):\n","    assert len(bits1) == len(bits2)\n","    for i in range(len(bits1)):\n","        if bits1[i] != bits2[i]:\n","            break\n","\n","    return i\n","\n","def encode_context(raw_text, enc):\n","    context_tokens = [enc.encoder['<|endoftext|>']] + enc.encode(raw_text)\n","    return context_tokens\n","\n","# Use gpt2-medium for 345M param model\n","# Use gpt2-large for 774M param model\n","def get_model(seed=1234, model_name='gpt2', device='cuda'):\n","    np.random.seed(seed)\n","    torch.random.manual_seed(seed)\n","    torch.cuda.manual_seed(seed)\n","\n","    enc = GPT2Tokenizer.from_pretrained(model_name)\n","    enc.unk_token = None\n","    enc.bos_token = None\n","    enc.eos_token = None\n","    \n","    model = GPT2LMHeadModel.from_pretrained(model_name)\n","    model.to(device)\n","    model.eval()\n","    # model.double()  # want to avoid using this\n","\n","    return enc, model\n","\n","enc32_itoc = ['\\0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '.', ',', \"'\", 
'!', ' ']\n","enc32_ctoi = {k: v for v, k in enumerate(enc32_itoc)}\n","def enc32(text):\n","    bits = []\n","    for c in text:\n","        bits.extend(int2bits(enc32_ctoi[c], 5))\n","    return bits\n","\n","def dec32(bits):\n","    text = ''\n","    for i in range(0, len(bits), 5):\n","        c = enc32_itoc[bits2int(bits[i:i+5])]\n","        if c == '\\0':\n","            break\n","        text += c\n","    return text\n","\n","# message should be bit string\n","# encoded should be text string\n","def expansion_ratio(message, encoded):\n","    message_bits = len(message)\n","    encoded_ba = bitarray.bitarray()\n","    encoded_ba.frombytes(encoded.encode('utf-8'))\n","    encoded_bits = len(encoded_ba.tolist())\n","    return encoded_bits/message_bits\n","\n","#@title\n","\n","import torch\n","import math\n","import random\n","\n","def bin_sort(l, token_indices, total, entropy, device):\n","    #compute entropy for upper bound on the number of bins we need\n","\n","    bucket_size = total\n","    num_bins = 2**int(entropy+1)\n","    bucket_size = total / num_bins\n","\n","    bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins\n","    value_in_bins = [0] * num_bins\n","    space_left_after = [total - i*bucket_size for i in range(0,num_bins)]\n","\n","\n","    token_bins = [torch.empty(0, dtype=torch.long, device=device)] * num_bins\n","\n","    # Figuring out what the search order should be\n","    step_size = num_bins/4\n","    search_order = []\n","    priorities = [0]*num_bins\n","    priority = 0\n","    search_order.append(int(num_bins/2))\n","    search_order.append(0)\n","    priorities[int(num_bins/2)] = 0\n","    priorities[0] = 0\n","    while(step_size>=1):\n","        priority += 1\n","        for x in range(num_bins-int(step_size), -1, -int(step_size*2)):\n","            search_order.append(x)\n","            priorities[x] = priority\n","        step_size = step_size/2\n","\n","    # Adding the actual elements\n","    for (item, 
token_index) in zip(l.tolist(), token_indices.tolist()):\n","        found_single_bucket_fit = False\n","        single_bucket_index = -1\n","        single_bucket_value = bucket_size\n","\n","        found_multi_bucket_bumpless_fit = False\n","        multi_bucket_bumpless_index = -1\n","        multi_bucket_bumpless_value = total\n","\n","        found_multi_bucket_bumping_fit = False\n","        multi_bucket_bumping_index = -1\n","        multi_bucket_bumping_value = total\n","\n","        for i in search_order:  # for index in search_order\n","            if(item > space_left_after[i]):\n","                continue\n","            if(value_in_bins[i] >= bucket_size):\n","                continue\n","\n","            # Priority of choices\n","            #  1. Can i place this thing in an empty bucket all on its own?\n","            #  2. Can i plan this somewhere where is doesnt have to bump anything else around?\n","            #    2a. Minimize the wasted space.  Aka use the smallest space (of equal priority) that accomplishes this goal\n","            #  3. If not (1) and (2), then put it in the space the bumps stuff the least.\n","\n","            if(value_in_bins[i] + item > bucket_size): #Would overflow. \n","\n","                space_before_next_block = bucket_size - value_in_bins[i]\n","                for j in range(i+1, len(bins)):\n","                    if(value_in_bins[j] > 0): # We have found a bucket with something in it.  
This is how much space we have here.\n","                        space_before_next_block = space_before_next_block + (bucket_size - value_in_bins[i])\n","                        break\n","                    else: # This was a empty bucket\n","                        space_before_next_block = space_before_next_block + bucket_size\n","\n","                if((not found_multi_bucket_bumpless_fit) or (found_multi_bucket_bumpless_fit and priorities[i] <= priorities[multi_bucket_bumpless_index])): #This could potentially be a match\n","\n","                    # If this is a valid space to put this without bumping and it is a better fit than previous spaces\n","                    if(space_before_next_block > item and space_before_next_block < multi_bucket_bumpless_value):\n","                        # set this to be the pointer!  we can fit stuff here\n","                        found_multi_bucket_bumpless_fit = True\n","                        multi_bucket_bumpless_index = i\n","                        multi_bucket_bumpless_value = space_before_next_block\n","\n","                    # Find the overflow that will bump the least\n","                    if ( item - space_before_next_block < multi_bucket_bumping_value):\n","                        found_multi_bucket_bumping_fit = True\n","                        multi_bucket_bumping_index = i\n","                        multi_bucket_bumping_value = item - space_before_next_block\n","\n","            if(value_in_bins[i] + item <= bucket_size): #Would fit\n","                if(single_bucket_value > value_in_bins[i]):\n","                    found_single_bucket_fit = True\n","                    single_bucket_value = value_in_bins[i]\n","                    single_bucket_index = i\n","\n","        if (single_bucket_index == multi_bucket_bumpless_index == multi_bucket_bumping_index == -1):\n","            bins[0] = torch.cat( (torch.tensor([item], device=device), bins[0]), 0)\n","            token_bins[0] = torch.cat( 
(torch.tensor([token_index], device=device), token_bins[0]), 0)\n","            continue\n","\n","\n","        if found_single_bucket_fit:\n","            # We found somewhere we can actually fit!\n","            bins[single_bucket_index] = torch.cat( (bins[single_bucket_index], torch.tensor([item], device=device)), 0)  \n","            token_bins[single_bucket_index] = torch.cat( (token_bins[single_bucket_index], torch.tensor([token_index], device=device)), 0)  \n","            value_in_bins[single_bucket_index] += item\n","            for i in range(0, single_bucket_index+1):\n","                space_left_after[i] -= item\n","\n","        elif found_multi_bucket_bumpless_fit:\n","            # Found somewhere we can put this without upsetting the force\n","            part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumpless_index]\n","            part_overflow = item - part_in_bucket\n","            bins[multi_bucket_bumpless_index] = torch.cat( (bins[multi_bucket_bumpless_index], torch.tensor([item], device=device)), 0)\n","            token_bins[multi_bucket_bumpless_index] = torch.cat( (token_bins[multi_bucket_bumpless_index], torch.tensor([token_index], device=device)), 0)  \n","            value_in_bins[multi_bucket_bumpless_index] = bucket_size\n","\n","            # Fill this bucket and continue overflowing\n","            j = multi_bucket_bumpless_index + 1\n","            for i in range(0, j):\n","                space_left_after[i] -= item\n","\n","            while(part_overflow > 0):\n","                new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size\n","                value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j]) # mark the bucket as filled\n","                space_left_after[j] -= part_overflow\n","                part_overflow = new_part_overflow\n","                j+=1\n","\n","        else:\n","            part_in_bucket = bucket_size - value_in_bins[multi_bucket_bumping_index]\n","         
   part_overflow = item - part_in_bucket\n","            bins[multi_bucket_bumping_index] = torch.cat( (bins[multi_bucket_bumping_index], torch.tensor([item], device=device)), 0)\n","            token_bins[multi_bucket_bumping_index] = torch.cat( (token_bins[multi_bucket_bumping_index], torch.tensor([token_index], device=device)), 0)\n","            value_in_bins[multi_bucket_bumping_index] = bucket_size\n","\n","            # Fill this bucket and continue overflowing\n","            j = multi_bucket_bumping_index + 1\n","            for i in range(0, j):\n","                space_left_after[i] -= item\n","            while(part_overflow > 0):\n","                new_part_overflow = (value_in_bins[j] + part_overflow) - bucket_size\n","                value_in_bins[j] = min(bucket_size, part_overflow+value_in_bins[j]) # mark the bucket as filled\n","                space_left_after[j] -= part_overflow\n","                part_overflow = new_part_overflow\n","                j+=1\n","\n","\n","\n","    sorted_tensor = torch.cat(bins, 0)\n","    sorted_tokens = torch.cat(token_bins, 0)\n","\n","    return sorted_tensor, sorted_tokens\n","\n","\n","def compute_ev(t, precision):\n","    expected_bits = []\n","    cum_probs = t.cumsum(0)\n","\n","    for selection in range(0, len(cum_probs)):\n","\n","        # Calculate new range as ints\n","        new_int_bottom = cum_probs[selection-1] if selection > 0 else 0\n","        new_int_top = cum_probs[selection]\n","\n","        # Convert range to bits\n","        new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))\n","        new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive\n","\n","        # Consume most significant bits which are now fixed and update interval\n","        num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)\n","        expected_bits.append(t[selection] * num_bits_encoded)\n","\n","  
  return(float(sum(expected_bits).item())/(2**precision))\n","\n","def visualize_bins(values_in_bins, bucket_size):\n","    out_str = \"[\"\n","    for b in values_in_bins:\n","        out_str = out_str + \"  \" + str(round(100*b/bucket_size,2)) +  \"  |\"\n","    out_str = out_str + \"]\"\n","    print(out_str)\n","\n","def visualize_distribution(l):\n","    total = sum(l)\n","    out_str = \"[\"\n","    for b in l:\n","        out_str = out_str + \"  \" + str(round(100*b/total,2)) +  \"  |\"\n","    out_str = out_str + \"]\"\n","    print(out_str) \n","\n","def compute_entropy(lists):\n","    total = sum(lists)\n","    entropy = -1*sum([ (x/total) * math.log2(x/total) for x in lists])\n","    return entropy\n","\n","\n","#@title\n","import torch\n","import torch.nn.functional as F\n","\n","import os\n","\n","# Constants for HMAC-DRBG -- MUST CHANGE FOR SECURE IMPLEMENTATION\n","sample_key = b'0x01'*64\n","sample_seed_prefix = b'sample'\n","sample_nonce_counter = b'\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00'\n","\n","\n","def encode_meteor(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000, is_sort=False, randomize_key=False, input_key=sample_key, input_nonce=sample_nonce_counter):\n","\n","    if randomize_key:\n","        input_key = os.urandom(64)\n","    mask_generator = DRBG(input_key, sample_seed_prefix + input_nonce)\n","    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)\n","\n","\n","    max_val = 2**precision\n","    threshold = 2**(-precision)\n","    cur_interval = [0, max_val] # bottom inclusive, top exclusive\n","\n","    prev = context\n","    output = context\n","    past = None\n","\n","    total_num = 0\n","    total_num_for_stats = 0\n","    total_log_probs = 0\n","    total_kl = 0 # in bits\n","    total_entropy_ptau = 0\n","    total_num_sents = 0\n","\n","    with torch.no_grad():\n","        i = 0\n","        sent_finish = 
False\n","        while i < len(message) or (finish_sent and not sent_finish):\n","            logits, past = model(prev.unsqueeze(0), past=past)\n","            past = limit_past(past)\n","            logits[0, -1, -1] = -1e20 # endoftext token can't happen\n","            logits[0, -1, 628] = -1e20 # 2 newlines token can't happen\n","            logits, indices = logits[0, -1, :].sort(descending=True)\n","            logits = logits.double()\n","            logits_temp = logits / temp\n","            probs_temp = F.softmax(logits_temp, dim=0)\n","            log_probs_temp = F.log_softmax(logits_temp, dim=0)\n","            log_probs = F.log_softmax(logits, dim=0)\n","\n","            # conditions for having reached the end of the message\n","            if i >= len(message):\n","                selection = 0\n","                sent_finish = is_sent_finish(indices[selection].item(), enc)\n","            else:\n","                # Cutoff low probabilities that would be rounded to 0\n","                cur_int_range = cur_interval[1]-cur_interval[0]\n","                cur_threshold = 1/cur_int_range\n","                k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)\n","                probs_temp_int = probs_temp[:k] # Cutoff all but top k\n","                old_indices = indices\n","                indices = indices[:k]\n","\n","                # Rescale to correct range\n","                probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range\n","\n","                entropy_in_this_distribution = entropy(probs_temp, log_probs_temp)\n","\n","                # Round probabilities to integers given precision\n","                probs_temp_int = probs_temp_int.round().long()\n","\n","                if is_sort:\n","                    probs_temp_int, indices = bin_sort(probs_temp_int, indices, cur_int_range, entropy_in_this_distribution, device)\n","                cum_probs = probs_temp_int.cumsum(0)\n","\n","                
# Remove any elements from the bottom if rounding caused the total prob to be too large\n","                overfill_index = (cum_probs > cur_int_range).nonzero()\n","                if len(overfill_index) > 0:\n","                    cum_probs = cum_probs[:overfill_index[0]]\n","\n","                # Add any mass to the top if removing/rounding causes the total prob to be too small\n","                cum_probs += cur_int_range-cum_probs[-1] # add\n","\n","                # Get out resulting probabilities\n","                probs_final = cum_probs.clone()\n","                probs_final[1:] = cum_probs[1:] - cum_probs[:-1]\n","\n","                # Convert to position in range\n","                cum_probs += cur_interval[0]\n","\n","                # Apply the mask to the message\n","                message_bits = message[i:i+precision]\n","                if i+precision > len(message):\n","                    message_bits = message_bits + [0]*(i+precision-len(message))\n","\n","                mask_bits = mask_generator.generate_bits(precision)\n","                for b in range(0, len(message_bits)):\n","                    message_bits[b] = message_bits[b] ^ mask_bits[b]\n","\n","                # Get selected index based on binary fraction from message bits\n","                message_idx = bits2int(reversed(message_bits))\n","                selection = (cum_probs > message_idx).nonzero()[0].item()\n","\n","                # Calculate new range as ints\n","                new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]\n","                new_int_top = cum_probs[selection]\n","\n","                # Convert range to bits\n","                new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))\n","                new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive\n","\n","                # Consume most significant bits which are now 
fixed and update interval\n","                num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)\n","                i += num_bits_encoded\n","\n","                # Gather statistics\n","                total_log_probs += log_probs[selection].item()\n","\n","                q = probs_final.double()/probs_final.sum()\n","                logq = q.log()\n","                total_kl += kl(q, logq, log_probs[:len(q)])\n","                total_entropy_ptau += entropy_in_this_distribution\n","                total_num_for_stats += 1\n","\n","            # Update history with new token\n","            prev = indices[selection].view(1)\n","            output = torch.cat((output, prev))\n","            total_num += 1\n","\n","            # For text->bits->text\n","            partial = enc.decode(output[len(context):].tolist())\n","            if '<eos>' in partial:\n","                break\n","\n","    avg_NLL = -total_log_probs/total_num_for_stats\n","    avg_KL = total_kl/total_num_for_stats\n","    avg_Hq = total_entropy_ptau/total_num_for_stats\n","    words_per_bit = total_num_for_stats/i\n","\n","    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit, avg_Hq\n","\n","def decode_meteor(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000, is_sort=False, input_key=sample_key, input_nonce=sample_nonce_counter):\n","    # inp is a list of token indices\n","    # context is a list of token indices\n","    inp = enc.encode(text)\n","\n","    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)\n","    mask_generator = DRBG(input_key, sample_seed_prefix + input_nonce)\n","\n","    max_val = 2**precision\n","    threshold = 2**(-precision)\n","    cur_interval = [0, max_val] # bottom inclusive, top exclusive\n","\n","    prev = context\n","    past = None\n","    message = []\n","    with torch.no_grad():\n","        i = 0\n","        while i < len(inp):\n","            logits, 
past = model(prev.unsqueeze(0), past=past)\n","            past = limit_past(past)\n","            logits[0, -1, -1] = -1e20 # endoftext can't happen\n","            logits[0, -1, 628] = -1e20 # 2 newlines can't happen\n","            logits, indices = logits[0, -1, :].sort(descending=True)\n","            logits = logits.double()\n","            logits_temp = logits / temp\n","            log_probs_temp = F.log_softmax(logits_temp, dim=0)\n","            probs_temp = F.softmax(logits_temp, dim=0)\n","\n","            # Cutoff low probabilities that would be rounded to 0\n","            cur_int_range = cur_interval[1]-cur_interval[0]\n","            cur_threshold = 1/cur_int_range\n","            k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)\n","            probs_temp_int = probs_temp[:k] # Cutoff all but top k\n","\n","            # Rescale to correct range\n","            probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range\n","            entropy_in_this_distribution = entropy(probs_temp, log_probs_temp)\n","\n","            # Round probabilities to integers given precision\n","            probs_temp_int = probs_temp_int.round().long()\n","            if is_sort:\n","                probs_temp_int, indices = bin_sort(probs_temp_int, indices, cur_int_range, entropy_in_this_distribution, device)\n","            cum_probs = probs_temp_int.cumsum(0)\n","\n","            # Remove any elements from the bottom if rounding caused the total prob to be too large\n","            overfill_index = (cum_probs > cur_int_range).nonzero()\n","            if len(overfill_index) > 0:\n","                cum_probs = cum_probs[:overfill_index[0]]\n","                k = overfill_index[0].item()\n","\n","            # Add any mass to the top if removing/rounding causes the total prob to be too small\n","            cum_probs += cur_int_range-cum_probs[-1] # add\n","\n","            # Covnert to position in range\n","            cum_probs += 
cur_interval[0]\n","\n","            rank = (indices == inp[i]).nonzero().item()\n","\n","            # Handle most errors that could happen because of BPE with heuristic\n","            if rank >= k:\n","                true_token_text = enc.decoder[inp[i]]\n","                for rank_idx in range(k):\n","                    prop_token_text = enc.decoder[indices[rank_idx].item()]\n","                    # common case that is not caught\n","                    if inp[i] == 128 and indices[rank_idx] == 198:\n","                        rank = rank_idx\n","                        inp[i] = indices[rank_idx].item()\n","                        break\n","            \n","                    # Is there a more likely prefix token that could be the actual token generated?\n","                    if len(prop_token_text) <= len(true_token_text) and \\\n","                            prop_token_text == true_token_text[:len(prop_token_text)]:\n","                        rank = rank_idx\n","                        suffix = true_token_text[len(prop_token_text):]\n","                        suffix_tokens = enc.encode(suffix) # a list\n","                        inp[i] = indices[rank_idx].item()\n","                        inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list\n","                        break\n","\n","                    # Is there a more likely longer token that could be the actual token generated?\n","                    elif len(prop_token_text) > len(true_token_text) and \\\n","                              true_token_text == prop_token_text[:len(true_token_text)]:\n","                        whole_text = true_token_text\n","                        num_extra = 1\n","                        while len(whole_text) < len(prop_token_text):\n","                            whole_text += enc.decoder[inp[i+num_extra]]\n","                            num_extra += 1\n","                        if prop_token_text == whole_text[:len(prop_token_text)]:\n","           
                 rank = rank_idx\n","                            inp[i] = indices[rank_idx].item()\n","                            for j in range(1, num_extra):\n","                                del inp[i+j]\n","\n","                            if len(whole_text) > len(prop_token_text):\n","                                suffix = whole_text[len(prop_token_text):]\n","                                suffix_tokens = enc.encode(suffix) # a list\n","                                inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list\n","                            break\n","                else:\n","                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))\n","                    rank = 0\n","\n","            selection = rank\n","\n","            # Calculate new range as ints\n","            new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]\n","            new_int_top = cum_probs[selection]\n","\n","            # Convert range to bits\n","            new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))\n","            new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive\n","\n","            # Emit most significant bits which are now fixed and update interval\n","            num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)\n","            if i == len(inp)-1:\n","                new_bits = new_int_bottom_bits_inc\n","            else:\n","                new_bits = new_int_top_bits_inc[:num_bits_encoded]\n","\n","            # Get the mask and apply it to the recovered bits\n","            mask_bits = mask_generator.generate_bits(precision)\n","            for b in range(0, len(new_bits)):\n","                new_bits[b] = new_bits[b] ^ mask_bits[b]\n","            message += new_bits\n","\n","            # Update history with new token\n","         
   prev = torch.tensor([inp[i]], device=device, dtype=torch.long)\n","\n","            i += 1\n","\n","    return message\n","\n","def encode_arithmetic(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000):\n","\n","    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)\n","\n","    max_val = 2**precision\n","    threshold = 2**(-precision)\n","    cur_interval = [0, max_val] # bottom inclusive, top exclusive\n","\n","    prev = context\n","    output = context\n","    past = None\n","\n","    total_num = 0\n","    total_num_for_stats = 0\n","    total_log_probs = 0\n","    total_kl = 0 # in bits\n","    total_entropy_ptau = 0\n","    total_num_sents = 0\n","\n","    with torch.no_grad():\n","        i = 0\n","        sent_finish = False\n","        while i < len(message) or (finish_sent and not sent_finish):\n","            logits, past = model(prev.unsqueeze(0), past=past)\n","            past = limit_past(past)\n","            logits[0, -1, -1] = -1e20 # endoftext token can't happen\n","            logits[0, -1, 628] = -1e20 # 2 newlines token can't happen\n","            logits, indices = logits[0, -1, :].sort(descending=True)\n","            logits = logits.double()\n","            logits_temp = logits / temp\n","            probs_temp = F.softmax(logits_temp, dim=0)\n","            log_probs_temp = F.log_softmax(logits_temp, dim=0)\n","            log_probs = F.log_softmax(logits, dim=0)\n","            \n","            # conditions for having reached the end of the message\n","            if i >= len(message):\n","                selection = 0\n","                sent_finish = is_sent_finish(indices[selection].item(), enc)\n","            else:\n","                # Cutoff low probabilities that would be rounded to 0\n","                cur_int_range = cur_interval[1]-cur_interval[0]\n","                cur_threshold = 1/cur_int_range\n","                k = min(max(2, (probs_temp < 
cur_threshold).nonzero()[0].item()), topk)\n","                probs_temp_int = probs_temp[:k] # Cutoff all but top k\n","\n","                # Rescale to correct range\n","                probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range\n","\n","                # Round probabilities to integers given precision\n","                probs_temp_int = probs_temp_int.round().long()\n","                cum_probs = probs_temp_int.cumsum(0)\n","\n","                # Remove any elements from the bottom if rounding caused the total prob to be too large\n","                overfill_index = (cum_probs > cur_int_range).nonzero()\n","                if len(overfill_index) > 0:\n","                    cum_probs = cum_probs[:overfill_index[0]]\n","\n","                # Add any mass to the top if removing/rounding causes the total prob to be too small\n","                cum_probs += cur_int_range-cum_probs[-1] # add\n","\n","                # Get out resulting probabilities\n","                probs_final = cum_probs.clone()\n","                probs_final[1:] = cum_probs[1:] - cum_probs[:-1]\n","\n","                # Convert to position in range\n","                cum_probs += cur_interval[0]\n","\n","                # Get selected index based on binary fraction from message bits\n","                message_bits = message[i:i+precision]\n","                if i+precision > len(message):\n","                    message_bits = message_bits + [0]*(i+precision-len(message))\n","                message_idx = bits2int(reversed(message_bits))\n","                selection = (cum_probs > message_idx).nonzero()[0].item()\n","\n","                # Calculate new range as ints\n","                new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]\n","                new_int_top = cum_probs[selection]\n","\n","                # Convert range to bits\n","                new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, 
precision)))\n","                new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive\n","\n","                # Consume most significant bits which are now fixed and update interval\n","                num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)\n","                i += num_bits_encoded\n","\n","                new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded\n","                new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded\n","\n","                cur_interval[0] = bits2int(reversed(new_int_bottom_bits))\n","                cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive\n","\n","                # Gather statistics\n","                total_log_probs += log_probs[selection].item()\n","\n","                q = probs_final.double()/probs_final.sum()\n","                logq = q.log()\n","                total_kl += kl(q, logq, log_probs[:len(q)])\n","                total_entropy_ptau += entropy(probs_temp, log_probs_temp)\n","                total_num_for_stats += 1\n","            \n","            # Update history with new token\n","            prev = indices[selection].view(1)\n","            output = torch.cat((output, prev))\n","            total_num += 1\n","            \n","            # For text->bits->text\n","            partial = enc.decode(output[len(context):].tolist())\n","            if '<eos>' in partial:\n","                break\n","            \n","    avg_NLL = -total_log_probs/total_num_for_stats\n","    avg_KL = total_kl/total_num_for_stats\n","    avg_Hq = total_entropy_ptau/total_num_for_stats\n","    words_per_bit = total_num_for_stats/i\n","\n","    return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit, avg_Hq\n","\n","def decode_arithmetic(model, enc, text, context, device='cuda', temp=1.0, 
precision=16, topk=50000):\n","    # inp is a list of token indices\n","    # context is a list of token indices\n","    inp = enc.encode(text)\n","    # common BPE error case: 198, 198 (two single newlines) is interpreted as 628 (one double-newline token)\n","    i = 0\n","    while i < len(inp):\n","        if inp[i] == 628:\n","            inp[i] = 198\n","            inp[i+1:i+1] = [198]\n","            i += 2\n","        else:\n","            i += 1\n","\n","    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)\n","\n","    max_val = 2**precision\n","    threshold = 2**(-precision)\n","    cur_interval = [0, max_val] # bottom inclusive, top exclusive\n","\n","    prev = context\n","    past = None\n","    message = []\n","    with torch.no_grad():\n","        i = 0\n","        while i < len(inp):\n","            logits, past = model(prev.unsqueeze(0), past=past)\n","            past = limit_past(past)\n","            logits[0, -1, -1] = -1e10 # endoftext can't happen\n","            logits[0, -1, 628] = -1e10 # 2 newlines can't happen\n","            logits, indices = logits[0, -1, :].sort(descending=True)\n","            logits = logits.double()\n","            logits_temp = logits / temp\n","            probs_temp = F.softmax(logits_temp, dim=0)\n","            \n","            # Cutoff low probabilities that would be rounded to 0\n","            cur_int_range = cur_interval[1]-cur_interval[0]\n","            cur_threshold = 1/cur_int_range\n","            k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk)\n","            probs_temp_int = probs_temp[:k] # Cutoff all but top k\n","\n","            # Rescale to correct range\n","            probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range\n","\n","            # Round probabilities to integers given precision\n","            probs_temp_int = probs_temp_int.round().long()\n","            cum_probs = probs_temp_int.cumsum(0)\n","\n","            # Remove any elements from 
the bottom if rounding caused the total prob to be too large\n","            overfill_index = (cum_probs > cur_int_range).nonzero()\n","            if len(overfill_index) > 0:\n","                cum_probs = cum_probs[:overfill_index[0]]\n","                k = overfill_index[0].item()\n","\n","            # Add any mass to the top if removing/rounding causes the total prob to be too small\n","            cum_probs += cur_int_range-cum_probs[-1] # add\n","\n","            # Convert to position in range\n","            cum_probs += cur_interval[0]\n","\n","            rank = (indices == inp[i]).nonzero().item()\n","\n","            # Handle most errors that could happen because of BPE with heuristic\n","            if rank >= k:\n","                true_token_text = enc.decoder[inp[i]]\n","                for rank_idx in range(k):\n","                    prop_token_text = enc.decoder[indices[rank_idx].item()]\n","                    # common case that is not caught\n","                    if inp[i] == 128 and indices[rank_idx] == 198:\n","                        rank = rank_idx\n","                        inp[i] = indices[rank_idx].item()\n","                        break\n","                    \n","                    # Is there a more likely prefix token that could be the actual token generated?\n","                    if len(prop_token_text) <= len(true_token_text) and \\\n","                            prop_token_text == true_token_text[:len(prop_token_text)]:\n","                        rank = rank_idx\n","                        suffix = true_token_text[len(prop_token_text):]\n","                        suffix_tokens = enc.encode(suffix) # a list\n","                        inp[i] = indices[rank_idx].item()\n","                        inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list\n","                        break\n","\n","                    # Is there a more likely longer token that could be the actual token generated?\n","                    
elif len(prop_token_text) > len(true_token_text) and \\\n","                              true_token_text == prop_token_text[:len(true_token_text)]:\n","                        whole_text = true_token_text\n","                        num_extra = 1\n","                        while len(whole_text) < len(prop_token_text):\n","                            whole_text += enc.decoder[inp[i+num_extra]]\n","                            num_extra += 1\n","                        if prop_token_text == whole_text[:len(prop_token_text)]:\n","                            rank = rank_idx\n","                            inp[i] = indices[rank_idx].item()\n","                            for j in range(1, num_extra):\n","                                del inp[i+j]\n","\n","                            if len(whole_text) > len(prop_token_text):\n","                                suffix = whole_text[len(prop_token_text):]\n","                                suffix_tokens = enc.encode(suffix) # a list\n","                                inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list\n","                            break\n","                else:\n","                    print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text))\n","                    rank = 0\n","            \n","            selection = rank\n","            \n","            # Calculate new range as ints\n","            new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0]\n","            new_int_top = cum_probs[selection]\n","\n","            # Convert range to bits\n","            new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision)))\n","            new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive\n","            \n","            # Emit most significant bits which are now fixed and update interval\n","            num_bits_encoded = 
num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc)\n","            if i == len(inp)-1:\n","                new_bits = new_int_bottom_bits_inc\n","            else:\n","                new_bits = new_int_top_bits_inc[:num_bits_encoded]\n","            message += new_bits\n","\n","            new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded\n","            new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded\n","\n","            cur_interval[0] = bits2int(reversed(new_int_bottom_bits))\n","            cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive\n","            \n","            # Update history with new token\n","            prev = torch.tensor([inp[i]], device=device, dtype=torch.long)\n","            i += 1\n","    \n","    return message\n","\n","model_name = 'gpt2-medium'\n","device = 'cuda'\n","\n","enc, model = get_model(model_name=model_name, device=device)\n","\n","#@title\n","def encode_message(message_str, context, key, nonce):\n","    temp = 0.95\n","    precision = 32\n","    topk = 50000\n","\n","    finish_sent = False\n","    meteor_sort = False\n","    meteor_random = False\n","\n","    # First encode message to uniform bits, without any context\n","    # (not essential this is arithmetic vs ascii, but it's more efficient when the message is natural language)\n","    context_tokens = encode_context(context, enc)\n","    message_ctx = [enc.encoder['<|endoftext|>']]\n","    message_str += '<eos>'\n","    message = decode_arithmetic(\n","        model, enc, message_str, message_ctx, precision=40, topk=60000, device=device)\n","\n","    # Next encode bits into cover text, using arbitrary context\n","    Hq = 0\n","    out, nll, kl, words_per_bit, Hq = encode_meteor(model, enc, message, context_tokens, temp=temp, finish_sent=finish_sent,\n","                                                    precision=precision, 
topk=topk, device=device, is_sort=meteor_sort, randomize_key=meteor_random, input_key=key, input_nonce=nonce)\n","    text = enc.decode(out)\n","\n","    print(\"=\"*40 + \" Encoding \" + \"=\"*40)\n","    print(text)\n","    print('=> ppl: %0.2f, kl: %0.3f, words/bit: %0.2f, bits/word: %0.2f, entropy: %.2f' %\n","          (math.exp(nll), kl, words_per_bit, 1/words_per_bit, Hq/0.69315))\n","    print(\"=\" * 90)\n","\n","    stats = {\n","        \"ppl\": math.exp(nll),\n","        \"kl\": kl,\n","        \"wordsbit\": words_per_bit,\n","        \"entropy\": Hq/0.69315\n","    }\n","    return text, stats\n","\n","\n","def decode_message(text, context, key, nonce):\n","    temp = 0.95 \n","    precision = 32\n","    topk = 50000\n","\n","    meteor_sort = False\n","\n","    # First encode message to uniform bits, without any context\n","    # (not essential this is arithmetic vs ascii, but it's more efficient when the message is natural language)\n","    context_tokens = encode_context(context, enc)\n","    message_ctx = [enc.encoder['<|endoftext|>']]\n","\n","    message_rec = decode_meteor(model, enc, text, context_tokens, temp=temp,\n","                                precision=precision, topk=topk, device=device, is_sort=meteor_sort, input_key=key, input_nonce=nonce)\n","\n","    reconst = encode_arithmetic(\n","        model, enc, message_rec, message_ctx, precision=40, topk=60000, device=device)\n","    reconst = enc.decode(reconst[0])\n","\n","    print(\"=\"*40 + \" Recovered Message \" + \"=\"*40)\n","    print(reconst[:-5])\n","    print(\"=\" * 99)\n","\n","    # Remove <eos>\n","    return reconst[:-5]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"b0w7H22MiiDD"},"source":["## Contexts\n","\n","A context is the initial state of the GPT-2 algorithm, prior to running any sampling. This impacts the topic of the output of the model. Use the following form field, `chosen-context`, to select one. 
We have provided three example contexts:\n","\n","* \"Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission.\"\n","\n","* \"The Alvarez hypothesis posits that the mass extinction of the dinosaurs and many other living things during the Cretaceous-Paleogene extinction event was caused by the impact of a large asteroid on the Earth. Prior to 2013, it was commonly cited as having happened about 65 million years ago, but Renne and colleagues (2013) gave an updated value of 66 million years. Evidence indicates that the asteroid fell in the Yucatan Peninsula, at Chicxulub, Mexico. The hypothesis is named after the father-and-son team of scientists Luis and Walter Alvarez, who first suggested it in 1980. 
Shortly afterwards, and independently, the same was suggested by Dutch paleontologist Jan Smit.\"\n","\n","* \"Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist.\"\n","\n","Feel free to edit and add your own!"]},{"cell_type":"code","metadata":{"id":"7bNcvb4gdVgJ"},"source":["#@title  { run: \"auto\", display-mode: \"form\" }\n","chosen_context = \"Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist.\" #@param [\"Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission.\", \"The Alvarez hypothesis posits that the mass extinction of the dinosaurs and many other living things during the Cretaceous-Paleogene extinction event was caused by the impact of a large asteroid on the Earth. Prior to 2013, it was commonly cited as having happened about 65 million years ago, but Renne and colleagues (2013) gave an updated value of 66 million years. Evidence indicates that the asteroid fell in the Yucatan Peninsula, at Chicxulub, Mexico. The hypothesis is named after the father-and-son team of scientists Luis and Walter Alvarez, who first suggested it in 1980. 
Shortly afterwards, and independently, the same was suggested by Dutch paleontologist Jan Smit.\", \"Despite a long history of research and wide-spread applications to censorship resistant systems, practical steganographic systems capable of embedding messages into realistic communication distributions, like text, do not exist.\"] {allow-input: true}\n","chosen_context += \"\\n\\n\"  # to add a little spacing"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vJTH2BtbrZh_"},"source":["## Message\n","\n","Time to set up a message! Use the `message_text` field below to input the message you want to encode. "]},{"cell_type":"code","metadata":{"id":"tftXoMOBsSQ3"},"source":["#@title  { run: \"auto\", display-mode: \"form\" }\n","message_text = \"sample text\" #@param {type:\"string\"}\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4R_lCTNny60m"},"source":["## Running\n","\n","Run the cell below (using the Play icon to the left of \"Run me!\") to start the Meteor encoder. You'll see the generated stegotext which contains your `message_text` under your `chosen_context`. The system will then decode the stegotext, and you should see your original message as output. \n","\n","Note that, due to issues with the GPT-2 algorithm interface, you sometimes may see extra output from a decoded stegotext. This does not impact the underlying security of the scheme."]},{"cell_type":"code","metadata":{"id":"MmzzrhXa57Vn","cellView":"form"},"source":["#@title Run me!\n","#@markdown Make sure to re-run this cell if you change the parameters above.\n","x = encode_message(message_text, chosen_context, b'\\x03'*64, b'\\x01'*64)\n","y = decode_message(x[0], chosen_context, b'\\x03'*64, b'\\x01'*64)"],"execution_count":null,"outputs":[]}]}
  6. spacelab-ccny/rtbb Public

    Root the (Ballot) Box: Designing Security Engineering Courses with E-Voting

    C 1 6
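After each token, `encode_arithmetic` and `decode_arithmetic` in the pinned Meteor demo notebook narrow the working interval and shift out the bits it has fixed. That step can be sketched in plain Python. This is a minimal sketch and not the notebook's code. The excerpt uses the helpers `int2bits`, `bits2int`, and `num_same_from_beg` without defining them, so the versions below are assumed reimplementations. They store bits least-significant first, which is what the notebook's `reversed(...)` calls suggest. `narrow_interval` is a hypothetical name.

```python
def int2bits(value: int, num_bits: int) -> list[int]:
    """Assumed helper: the num_bits binary digits of value, least-significant bit first."""
    return [(value >> k) & 1 for k in range(num_bits)]


def bits2int(bits) -> int:
    """Assumed helper: inverse of int2bits (any iterable, least-significant bit first)."""
    return sum(bit << k for k, bit in enumerate(bits))


def num_same_from_beg(bits1: list[int], bits2: list[int]) -> int:
    """Assumed helper: length of the common prefix of two bit lists."""
    count = 0
    for a, b in zip(bits1, bits2):
        if a != b:
            break
        count += 1
    return count


def narrow_interval(new_bottom: int, new_top: int, precision: int):
    """Hypothetical wrapper around the interval update done after each token.

    [new_bottom, new_top) is the selected token's sub-interval (top exclusive).
    Returns the most significant bits that are now fixed, plus the next
    working interval rescaled to the full precision-bit range.
    """
    # Most-significant-first views of both ends (-1 because the top is exclusive)
    bottom_bits = list(reversed(int2bits(new_bottom, precision)))
    top_bits = list(reversed(int2bits(new_top - 1, precision)))

    # Leading bits shared by every value in the interval are fixed
    n = num_same_from_beg(bottom_bits, top_bits)
    fixed_bits = top_bits[:n]

    # Shift the fixed bits out: pad the bottom with 0s and the top with 1s
    next_bottom = bits2int(reversed(bottom_bits[n:] + [0] * n))
    next_top = bits2int(reversed(top_bits[n:] + [1] * n)) + 1  # back to exclusive
    return fixed_bits, [next_bottom, next_top]
```

For example, `narrow_interval(5, 7, precision=4)` returns `([0, 1], [4, 12])`. Both 5 (`0101`) and 6 (`0110`) begin with `01`, so those two bits are fixed. The remaining suffixes `01` and `10` are then stretched back over the 4-bit range.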
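Before each selection, the same notebook functions turn the model's probabilities into integer bounds. They drop tokens whose mass falls below one unit of the current interval, rescale, round, trim entries pushed past the range, give any shortfall to the most likely token, and offset by the interval's bottom. This hypothetical sketch mirrors those steps without PyTorch. `quantize_cumulative` and `select_token` are illustrative names. The input is assumed to be a plain list of probabilities already sorted in descending order, as the notebook's `sort(descending=True)` produces.

```python
from itertools import accumulate


def quantize_cumulative(probs: list[float], cur_interval: list[int], topk: int) -> list[int]:
    """Map descending-sorted probabilities to integer cumulative bounds in cur_interval."""
    bottom, top = cur_interval
    int_range = top - bottom

    # Keep tokens up to the first one whose mass would round to zero (at least 2, at most topk)
    threshold = 1 / int_range
    cutoff = next((j for j, p in enumerate(probs) if p < threshold), len(probs))
    kept = probs[:min(max(2, cutoff), topk)]

    # Rescale the kept mass to the integer range, round, and accumulate
    total = sum(kept)
    cum = list(accumulate(round(p / total * int_range) for p in kept))

    # Drop entries that rounding pushed past the range
    overfill = next((j for j, c in enumerate(cum) if c > int_range), None)
    if overfill is not None:
        cum = cum[:overfill]

    # Raise every bound by the shortfall (this enlarges only the first token's slot),
    # then convert to absolute positions in the current interval
    shortfall = int_range - cum[-1]
    return [c + shortfall + bottom for c in cum]


def select_token(cum: list[int], message_idx: int) -> int:
    """Rank of the first token whose upper bound exceeds message_idx."""
    return next(j for j, c in enumerate(cum) if c > message_idx)
```

For example, `quantize_cumulative([0.5, 0.3, 0.2], [0, 16], topk=50000)` gives `[8, 13, 16]`, and `select_token` then maps a message value of 10 to rank 1. The notebook's `.nonzero()[0]` indexing assumes that at least one probability falls below the threshold. This sketch instead keeps every token when none does.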