Veiled Xor

June 11, 2025 Bi0s CTF 2025 #crypto #GMP

Link to the challenge

Challenge Overview

This challenge involves an RSA-based encryption scheme in which we are given $n = p \times q$ and $p \oplus q[::-1]$, where $q[::-1]$ is the bit-reversed version of $q$. The goal is to recover $p$ and $q$ in order to compute the private key and decrypt the given ciphertext $c$.

Approach

We are going to recover $p$ and $q$ bit by bit.

The idea is simple: at each iteration we

  1. Branch on the next MSB and LSB of $p$ and $q$.
  2. Compute the corresponding $p \times q$ and $p \oplus q[::-1]$.
  3. Check whether the computed values are compatible with the given values.

5-bit width example

Let’s consider a simple problem where $p$ and $q$ are 5 bits long. Let’s suppose we have

| Variable | Binary value | Base-10 value |
|---|---|---|
| $p$ | 11011 | 27 |
| $q$ | 10011 | 19 |
| $q[::-1]$ | 11001 | 25 |
| $p \times q$ | 1000000001 | 513 |
| $p \oplus q[::-1]$ | 00010 | 2 |

The only values we know are those in the last two rows.
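The table can be double-checked with a few lines of Python. This is a small sketch; `reverse_bits` is a helper introduced here (not taken from the solver) that pads $q$ to `WIDTH` bits before reversing, which matters whenever $q$ has leading zeros.

```python
WIDTH = 5  # bit width of p and q in this toy example


def reverse_bits(x: int, width: int = WIDTH) -> int:
    """Reverse the `width`-bit binary representation of x (q[::-1] in the text)."""
    return int(format(x, f"0{width}b")[::-1], 2)


p, q = 0b11011, 0b10011
q_rev = reverse_bits(q)
n = p * q
x = p ^ q_rev  # the "veiled XOR" p ⊕ q[::-1]

print(p, q, q_rev, n, x)            # 27 19 25 513 2
print(format(n, f"0{2 * WIDTH}b"))  # 1000000001
print(format(x, f"0{WIDTH}b"))      # 00010
```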

We start by naming the outer bits of $p$ and $q$ and leaving the middle bits unknown:

| Variable | Binary value |
|---|---|
| $p$ | $a$???$i$ |
| $q$ | $b$???$j$ |
| $q[::-1]$ | $j$???$b$ |
| $p \oplus q[::-1]$ | $(a \oplus j)$???$(i \oplus b)$ |
| $p \times q$ | $(a \times b)$???????$(i \times j \bmod 2)$ |

The easiest conditions to derive are those on the XOR and on the LSB of $n = p \times q$. We also know that $a \times b$ cannot exceed the two most significant bits of $n$. This gives:

$$\left\lbrace \begin{aligned} i \oplus b &= 0 \\ a \oplus j &= 0 \\ i \times j &\equiv 1 \pmod 2 \\ a \times b &\leq n[:2] \end{aligned} \right.$$
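A quick brute force over the 16 possible values of $(a, b, i, j)$ shows how much these four conditions already constrain the toy example: only one assignment survives, and it matches $p = 11011$, $q = 10011$. This is a minimal sketch of the conditions exactly as written above, taking $n[:2]$ to be the two most significant bits of the 10-bit $n$.

```python
import itertools

WIDTH = 5
n, x = 513, 0b00010  # the two known values from the table

n_top2 = n >> (2 * WIDTH - 2)  # two most significant bits of the 10-bit n (0b10)
survivors = []
# p = a???i and q = b???j: only the outer bits are guessed
for a, b, i, j in itertools.product((0, 1), repeat=4):
    if i ^ b != x & 1:                   # LSB of p ⊕ q[::-1]
        continue
    if a ^ j != (x >> (WIDTH - 1)) & 1:  # MSB of p ⊕ q[::-1]
        continue
    if (i * j) % 2 != n & 1:             # LSB of n
        continue
    if a * b > n_top2:                   # top of the partial product cannot exceed n
        continue
    survivors.append((a, b, i, j))

print(survivors)  # [(1, 1, 1, 1)]
```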

The tricky part is deriving a condition on the MSBs of $n = p \times q$, since they depend on the carries of the multiplication. To find such a condition, I tried to maximise the distance between the MSBs of $p \times q$ (computed from the guessed bits only) and the MSBs of $n$.

For instance, the largest value the two most significant bits of $n$ can take is $11$, which happens when all the bits of $p$ and $q$ are set to $1$, thanks to carry propagation. This value is not reachable when we only know the MSBs of $p$ and $q$: with the unknown bits set to $0$ (as in the implementation), the largest value these two bits can take is $01$. In that case the difference is $3 - 1 = 2$.

Now suppose we are at the $k$-th iteration. By a similar argument, I considered that at each step the maximal acceptable distance between the $k+1$ most significant bits of $n$ and the $k+1$ most significant bits of $p \times q$ is $k + 1$.
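The $k+1$ bound can be sanity-checked exhaustively on tiny widths. The hypothetical `max_gap` helper below keeps only the $k$ outer bits on each side of $p$ and $q$ (everything else set to $0$, as in the implementation) and measures how far the top $k+1$ bits of $p' \times q'$ fall below those of $n$. It reproduces the gap of $2$ found for the 5-bit case and checks that the gap stays within $k+1$ for every $k$ at width 8; this is a check on small sizes, not a proof for 1024-bit factors.

```python
def max_gap(width: int, k: int) -> int:
    """Largest gap between the top k+1 bits of n = p*q and those of p'*q'.

    p' and q' keep only the k most and k least significant bits of p and q
    (all other bits set to 0); p and q range over every `width`-bit value.
    """
    total = 2 * width
    shift = total - (k + 1)             # (A)_{->k+1} is A >> (total - k - 1)
    low = (1 << k) - 1
    known = low | (low << (width - k))  # the mask m: k ones at each end
    worst = 0
    for p in range(1 << width):
        for q in range(1 << width):
            n_top = (p * q) >> shift
            guess_top = ((p & known) * (q & known)) >> shift
            assert guess_top <= n_top   # p' <= p and q' <= q, so never above n
            worst = max(worst, n_top - guess_top)
    return worst


print(max_gap(5, 1))  # 2, the gap derived above for the 5-bit case

# Width 8: the gap never exceeds k + 1 for any k (no assertion is raised).
for k in range(1, 8 // 2 + 1):
    assert max_gap(8, k) <= k + 1
```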

Proof of the carry propagation property

To prove this property, consider the multiplication with the largest possible carry in every column. It is obtained when $p$ and $q$ are two $l$-bit numbers made only of $1$s. Reading the columns from the most significant one, the carry entering the $k$-th column is exactly $k$ up to a peak of $l-1$; it stays at $l-1$ for one more column and then decreases by $1$ per column, down to $0$ for the two least significant columns. This can be proven by induction.

Thus the carry entering the $k$-th most significant bit of $p \times q$ is at most $k$.

To see what this looks like in an example, take $l = 5$ and compute $p \times q$ as we would in elementary school.

p =         11111
q =         11111
-----------------
=           11111
+          111110
+         1111100
+        11111000
+       111110000
-----------------
sum     123454321 in base 10
carry  1234432100 in base 10
-----------------
result 1111000001 in base 2
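The same table can be reproduced by simulating the schoolbook multiplication column by column. In this sketch (the function name `schoolbook_all_ones` is made up for illustration), the column sums and carries are printed with the most significant column first, as in the listing above.

```python
def schoolbook_all_ones(l: int) -> tuple[list[int], list[int], int]:
    """Multiply two l-bit all-ones numbers column by column.

    Returns the column sums, the carry entering each column, and the product.
    Index 0 is the least significant column.
    """
    width = 2 * l
    sums = [0] * width
    for i in range(l):
        for j in range(l):
            sums[i + j] += 1  # partial product bit p_i * q_j lands in column i + j
    carries, bits, carry = [0] * width, [0] * width, 0
    for col in range(width):
        carries[col] = carry
        total = sums[col] + carry
        bits[col] = total & 1  # bit written in this column
        carry = total >> 1     # carry passed to the next column
    product = sum(bit << col for col, bit in enumerate(bits))
    return sums, carries, product


sums, carries, product = schoolbook_all_ones(5)
print("".join(map(str, reversed(sums))).lstrip("0"))  # 123454321
print("".join(map(str, reversed(carries))))           # 1234432100
print(format(product, "b"))                           # 1111000001
```

Reading `carries` from the most significant column, the carry entering the $k$-th column never exceeds $k$, which is the property used to bound the gap.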

Generalisation to the $k$-th step

Suppose we have guessed the $k$ most significant and the $k$ least significant bits of $p$ and $q$, and call these hypothetical solutions $p'$ and $q'$ (all other bits set to $0$). Let $m$ be the mask $1\ldots10\ldots01\ldots1$ with $k$ bits set to $1$ at each end and $1024 - 2k$ bits set to $0$ in the middle. Let $x$ be the veiled XOR, $x = p \oplus q[::-1]$, and $n$ the product, $n = p \times q$. Finally, let $(A)_{\to i}$ denote the number made of the $i$ most significant bits of a given $A$ (for a 2048-bit integer, this is a right shift by $2048 - i$).

The conditions they must satisfy are:

$$\left\lbrace \begin{aligned} (p' \oplus q'[::-1]) \land m &= x \land m \\ p' \times q' &\equiv n \pmod{2^{k}} \\ (p' \times q')_{\to k+1} &\leq n_{\to k+1} \\ n_{\to k+1} - (p' \times q')_{\to k+1} &\leq k+1 \end{aligned} \right.$$
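The four conditions translate almost literally to Python integers. The sketch below is an illustration, not the author's solver: `satisfies` and `reverse_bits` are hypothetical names, `width` stands for the 1024-bit size of $p$ and $q$, and the guesses are expected to have every bit outside the mask set to $0$. It is exercised on the 5-bit example.

```python
def reverse_bits(x: int, width: int) -> int:
    """Bit-reverse x as a `width`-bit number (q[::-1] in the text)."""
    return int(format(x, f"0{width}b")[::-1], 2)


def satisfies(p_g: int, q_g: int, n: int, x: int, k: int, width: int) -> bool:
    """Check the four conditions for guesses p', q' (bits outside the mask are 0).

    `width` is the bit size of p and q, so n has 2*width bits.
    """
    total = 2 * width
    low = (1 << k) - 1
    m = low | (low << (width - k))    # mask 1...10...01...1
    prod = p_g * q_g
    if (p_g ^ reverse_bits(q_g, width)) & m != x & m:
        return False                  # condition 1: veiled XOR on the known bits
    if prod % (1 << k) != n % (1 << k):
        return False                  # condition 2: p'q' == n (mod 2^k)
    shift = total - (k + 1)           # (A)_{->k+1} is A >> (2*width - k - 1)
    gap = (n >> shift) - (prod >> shift)
    return 0 <= gap <= k + 1          # conditions 3 and 4


# The 5-bit example: p = 11011, q = 10011, n = 513, x = 00010.
print(satisfies(0b10001, 0b10001, 513, 0b00010, 1, 5))  # True: correct outer bits
print(satisfies(0b10000, 0b10001, 513, 0b00010, 1, 5))  # False: wrong LSB for p
print(satisfies(0b11011, 0b10011, 513, 0b00010, 2, 5))  # True: here p' = p, q' = q
```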

To solve the problem, we keep a list of candidate pairs $(p', q')$. At each step, every candidate is extended in $16$ ways (one for each assignment of the new most significant and least significant bits of $p'$ and $q'$), and only the children that satisfy the conditions above are kept. After $512$ steps every bit is fixed, and the remaining pair with $p' \times q' = n$ gives the factorization from which we compute the private key.
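Putting the pieces together, here is a compact sketch of the search on Python integers instead of boolean lists. It follows the same level-by-level structure as `progressive_factorization` in the solvers below but is not the author's code: `recover` and `consistent` are hypothetical names, it assumes an even bit width, and it runs on an arbitrary 16-bit toy instance rather than the challenge values.

```python
import itertools


def reverse_bits(x: int, width: int) -> int:
    """Bit-reverse x as a `width`-bit number (q[::-1] in the text)."""
    return int(format(x, f"0{width}b")[::-1], 2)


def recover(n: int, x: int, width: int) -> list[tuple[int, int]]:
    """Branch-and-prune sketch: fix one more outer bit of p and q on each side per step.

    Assumes `width` (the bit size of p and q) is even, so the two halves meet exactly.
    """
    total = 2 * width

    def consistent(p: int, q: int, k: int) -> bool:
        low = (1 << k) - 1
        m = low | (low << (width - k))       # k ones at each end
        prod = p * q
        if (p ^ reverse_bits(q, width)) & m != x & m:
            return False                     # veiled XOR mismatch
        if (prod - n) % (1 << k) != 0:
            return False                     # p'q' must agree with n mod 2^k
        shift = total - k - 1                # keep the top k+1 bits
        gap = (n >> shift) - (prod >> shift)
        return 0 <= gap <= k + 1             # carry bound from the text

    candidates = [(0, 0)]                    # nothing guessed yet
    for k in range(1, width // 2 + 1):
        hi, lo = width - k, k - 1            # bit positions fixed at this step
        survivors = []
        for p, q in candidates:
            # 16 children: new MSB/LSB for p and for q
            for p_hi, p_lo, q_hi, q_lo in itertools.product((0, 1), repeat=4):
                child_p = p | (p_hi << hi) | (p_lo << lo)
                child_q = q | (q_hi << hi) | (q_lo << lo)
                if consistent(child_p, child_q, k):
                    survivors.append((child_p, child_q))
        candidates = survivors
    # Once k = width/2 every bit is fixed, so keep only exact factorizations.
    return [(p, q) for p, q in candidates if p * q == n]


# Toy instance with arbitrary odd 16-bit factors (not the challenge values).
p, q, width = 0xE3A7, 0xC1F5, 16
n, x = p * q, p ^ reverse_bits(q, width)
print((p, q) in recover(n, x, width))  # True
```

Most of the 16 children die immediately: the XOR alone ties the new MSB of $p$ to the new LSB of $q$ (and the new LSB of $p$ to the new MSB of $q$), cutting them to 4 before the product checks run.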

Once this pair is found, decryption is standard RSA. The private exponent is computed as $$d = e^{-1} \bmod \phi(n)$$

where $e$ is the public exponent and $\phi(n) = (p-1)(q-1)$.

Then we decrypt the ciphertext $c$ with $$m = c^{d} \bmod n$$

where $d$ is the private exponent and $m$ is the plaintext (not to be confused with the mask $m$ used above).
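As a quick illustration of these two formulas, here is the classic textbook toy example ($p = 61$, $q = 53$, $e = 17$), unrelated to the challenge values. `pow(e, -1, phi)` computes the modular inverse and needs Python 3.8 or newer.

```python
# Textbook toy parameters (not the challenge values)
p, q, e = 61, 53, 17
n = p * q                # 3233
phi = (p - 1) * (q - 1)  # 3120
d = pow(e, -1, phi)      # private exponent: e^-1 mod phi(n)
c = pow(65, e, n)        # encrypt the message 65
m = pow(c, d, n)         # decrypt: c^d mod n
print(d, c, m)           # 2753 2790 65
```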

Implementation in Python

Python is a really bad language when it comes to manipulating bits, but I wanted to draft a quick program to try the idea and see if the complexity was good enough. To represent the bits I used an array of booleans :sweat_smile:.

Implementation in C with GMP and multi-threading

Once I was convinced the approach worked and realised that a complete run would take about an hour, I rewrote it in C, using GMP to handle the big integers. Solving then took about 15 minutes on my computer. I then realised I could speed it up further by multi-threading the pruning of the branches: the final implementation (solve.c) solves the challenge in 2 minutes.

boolean_solve.py

SIZE = 1024

def int_to_bool_list(num: int, length: int) -> list[bool]:
    binary_str = bin(num)[2:].zfill(length)
    if len(binary_str) > length:
        raise ValueError(f"Number {num} requires more than {length} bits")
    return [bit == '1' for bit in binary_str]

def bool_list_to_int(bool_list: list[bool]) -> int:
    return int(''.join('1' if b else '0' for b in bool_list), 2)

def reverse_bits(bool_list: list[bool]) -> list[bool]:
    return bool_list[::-1]

def xor_bool_lists(a: list[bool], b: list[bool]) -> list[bool]:
    assert len(a) == len(b), f"Lists must be same length: {len(a)} != {len(b)}"
    return [x ^ y for x, y in zip(a, b)]

def multiply_bool_lists(a: list[bool], b: list[bool]) -> list[bool]:
    """
    Multiply two SIZE-bit numbers represented as boolean lists.
    Returns a 2*SIZE-bit result.
    """
    assert len(a) == SIZE, f"First operand must be SIZE bits, got {len(a)}"
    assert len(b) == SIZE, f"Second operand must be SIZE bits, got {len(b)}"

    int_a = bool_list_to_int(a)
    int_b = bool_list_to_int(b)
    result = int_a * int_b

    return int_to_bool_list(result, 2 * SIZE)

def check_constraints(p: list[bool], q: list[bool], n: list[bool], veil_xor: list[bool], mask_size: int) -> bool:
    assert len(p) == SIZE, f"p must be SIZE bits, got {len(p)}"
    assert len(q) == SIZE, f"q must be SIZE bits, got {len(q)}"
    assert len(n) == 2 * SIZE, f"n must be 2*SIZE bits, got {len(n)}"
    assert len(veil_xor) == SIZE, f"veil_xor must be SIZE bits, got {len(veil_xor)}"

    p_masked = [False] * SIZE
    q_masked = [False] * SIZE

    for i in range(mask_size):
        p_masked[i] = p[i]
        q_masked[i] = q[i]

    for i in range(mask_size):
        p_masked[SIZE - mask_size + i] = p[SIZE - mask_size + i]
        q_masked[SIZE - mask_size + i] = q[SIZE - mask_size + i]

    product = multiply_bool_lists(p_masked, q_masked)

    # Check MSB bits of product
    top_n = bool_list_to_int(n[: mask_size + 1])
    top_product = bool_list_to_int(product[: mask_size + 1])
    if top_n - top_product > mask_size + 1:
        return False
    if top_product > top_n:
        return False

    # Check LSB bits of product
    for i in range(mask_size):
        if product[2 * SIZE - mask_size + i] != n[2 * SIZE - mask_size + i]:
            return False

    # Check veil XOR constraint for the masked bits
    q_reversed = reverse_bits(q_masked)
    p_xor_q_rev = xor_bool_lists(p_masked, q_reversed)

    # Check MSB bits of XOR
    for i in range(mask_size):
        if p_xor_q_rev[i] != veil_xor[i]:
            return False

    # Check LSB bits of XOR
    for i in range(mask_size):
        if p_xor_q_rev[SIZE - mask_size + i] != veil_xor[SIZE - mask_size + i]:
            return False

    return True

def generate_candidates(
    p_template: list[bool], q_template: list[bool], mask_size: int
) -> list[tuple[list[bool], list[bool]]]:
    """
    Generate new candidates by setting the next MSB and LSB bits.
    """
    assert len(p_template) == SIZE, f"p_template must be SIZE bits, got {len(p_template)}"
    assert len(q_template) == SIZE, f"q_template must be SIZE bits, got {len(q_template)}"

    candidates = []

    # Try all combinations of the new MSB and LSB bits for both p and q
    for p_msb in [False, True]:
        for p_lsb in [False, True]:
            for q_msb in [False, True]:
                for q_lsb in [False, True]:
                    new_p = p_template.copy()
                    new_q = q_template.copy()

                    # Set the new MSB bit
                    new_p[mask_size - 1] = p_msb
                    new_q[mask_size - 1] = q_msb

                    # Set the new LSB bit
                    new_p[SIZE - mask_size] = p_lsb
                    new_q[SIZE - mask_size] = q_lsb

                    candidates.append((new_p, new_q))

    return candidates

def progressive_factorization(n: int, veil_xor: int) -> tuple[int, int]:
    n_bits = int_to_bool_list(n, 2 * SIZE)
    veil_xor_bits = int_to_bool_list(veil_xor, SIZE)

    print(f"n has {len(n_bits)} bits")
    print(f"veil_xor has {len(veil_xor_bits)} bits")

    valid_candidates = []

    for i1 in [False, True]:
        for i2 in [False, True]:
            for i3 in [False, True]:
                for i4 in [False, True]:
                    # Create SIZE-bit templates with only the first and last bits set
                    p_template = [False] * SIZE
                    q_template = [False] * SIZE

                    p_template[0] = i1
                    p_template[-1] = i2
                    q_template[0] = i3
                    q_template[-1] = i4

                    if check_constraints(p_template, q_template, n_bits, veil_xor_bits, 1):
                        valid_candidates.append((p_template, q_template))

    print(f"Starting with {len(valid_candidates)} valid 1-bit candidates")

    # Progressive extension
    for mask_size in range(2, SIZE // 2 + 1):  # Up to 512 bits from each end
        new_candidates = []

        for p_template, q_template in valid_candidates:
            candidates = generate_candidates(p_template, q_template, mask_size)

            for new_p, new_q in candidates:
                if check_constraints(new_p, new_q, n_bits, veil_xor_bits, mask_size):
                    new_candidates.append((new_p, new_q))

        valid_candidates = new_candidates
        print(f"After {mask_size} bits: {len(valid_candidates)} valid candidates")

        if len(valid_candidates) == 0:
            print("No more valid candidates found")
            break

    if len(valid_candidates) == 0:
        raise ValueError("No valid factors found")

    # Return the first valid solution
    p_final, q_final = valid_candidates[0]

    # Verify lengths one more time
    assert len(p_final) == SIZE, f"Final p must be SIZE bits, got {len(p_final)}"
    assert len(q_final) == SIZE, f"Final q must be SIZE bits, got {len(q_final)}"

    p_int = bool_list_to_int(p_final)
    q_int = bool_list_to_int(q_final)

    return p_int, q_int

def solve_challenge(n: int, c: int, veil_xor: int) -> bytes:
    print("Starting progressive factorization...")
    p, q = progressive_factorization(n, veil_xor)

    print("Found factors:")
    print(f"p = {p}")
    print(f"q = {q}")

    # Verify the solution
    assert p * q == n, f"Invalid factorization: {p} * {q} != {n}"

    # Verify veil XOR constraint
    q_reversed_int = int(bin(q)[2:].zfill(SIZE)[::-1], 2)  # pad to SIZE bits so leading zeros are reversed too
    assert p ^ q_reversed_int == veil_xor, f"Invalid veil XOR: {p} ^ {q_reversed_int} != {veil_xor}"

    print("Factorization verified!")

    # Compute private exponent
    phi = (p - 1) * (q - 1)
    e = 65537
    d = pow(e, -1, phi)

    # Decrypt
    m = pow(c, d, n)

    # Convert to bytes
    flag = m.to_bytes((m.bit_length() + 7) // 8, 'big')

    return flag

if __name__ == "__main__":
    n = 25650993834245004720946189793874326497984795849338302417110946799293291648040249066481025511053012034073848003478136002015789778483853455736405270138192685004206122168607287667373629714589814547144217162436740164024414206705483947822707673759856022882063396271521077034396144039740088690783163935477234001508676877728359035563304374705319120303835098697559771353065115371216095633826663393222290375210498159025443467666369652776698531368926392564476840557482790175694984871271075976052162527476586777386578254654222259777299785563550342986250558793337690540798983389913689337683350216697595855274995968459458553148267
    c = 7874419222145223100478995004906732383469089972173454594282476506666095078687712494332749473566534625352139353593310707008146533254390514332880136585545606758108380402050369451711762195058199249765633645224407166178729834108159734540770902813439688437621416030538050164358987313607945402928893945400086827254622507315341530235984071126104731692679123171962413857123065243252313290356908958113679070546907527095194888688858140118665219670816655147095649132221436351529029926610142793850463533766705147562234382644751744682744799743855986811769162311342911946128543115444104102909314075691320520722623778914052878038508
    veiled_xor = 26845073698882094013214557201710791833291706601384082712658811014034994099681783926930272036664572532136049856667171349310624166258134687815795133386046337514685147643316723034719743474088423205525505355817639924602251866472741277968741560579392242642848932606998045419509860412262320853772858267058490738386

    flag = solve_challenge(n, c, veiled_xor)
    print("Decrypted flag:", flag.decode('utf-8', errors='ignore'))

solve.c

# include <assert.h>
# include <gmp.h>
# include <pthread.h>
# include <stdio.h>
# include <stdlib.h>
# include <string.h>
# include <unistd.h>

# define SIZE 1024
# define MAX_CANDIDATES 1000000

typedef struct {
  mpz_t p;
  mpz_t q;
} candidate_t;

typedef struct {
  candidate_t *candidates;
  int count;
  int capacity;
} candidate_list_t;

typedef struct {
  int thread_id;
  int start_idx;
  int end_idx;
  candidate_list_t *input_candidates;
  candidate_list_t *output_candidates;
  mpz_t n;
  mpz_t veil_xor;
  int mask_size;
  pthread_mutex_t *output_mutex;
} thread_data_t;

void init_candidate_list(candidate_list_t *list) {
  list->capacity = 10000;
  list->count = 0;
  list->candidates = malloc(list->capacity * sizeof(candidate_t));
  for (int i = 0; i < list->capacity; i++) {
    mpz_init(list->candidates[i].p);
    mpz_init(list->candidates[i].q);
  }
}

void add_candidate(candidate_list_t *list, mpz_t p, mpz_t q) {
  if (list->count >= list->capacity) {
    if (list->capacity >= MAX_CANDIDATES) {
      printf("Reached maximum candidates limit of %d\n", MAX_CANDIDATES);
      return;
    }
    list->capacity *= 2;
    if (list->capacity > MAX_CANDIDATES) list->capacity = MAX_CANDIDATES;
    list->candidates =
      realloc(list->candidates, list->capacity * sizeof(candidate_t));
    for (int i = list->count; i < list->capacity; i++) {
      mpz_init(list->candidates[i].p);
      mpz_init(list->candidates[i].q);
    }
  }

  mpz_set(list->candidates[list->count].p, p);
  mpz_set(list->candidates[list->count].q, q);
  list->count++;
}

void add_candidate_safe(
  candidate_list_t *list, mpz_t p, mpz_t q, pthread_mutex_t *mutex
) {
  pthread_mutex_lock(mutex);

  if (list->count >= list->capacity) {
    if (list->capacity >= MAX_CANDIDATES) {
      printf("Reached maximum candidates limit of %d\n", MAX_CANDIDATES);
      pthread_mutex_unlock(mutex);
      return;
    }
    list->capacity *= 2;
    if (list->capacity > MAX_CANDIDATES) list->capacity = MAX_CANDIDATES;
    list->candidates =
      realloc(list->candidates, list->capacity * sizeof(candidate_t));
    for (int i = list->count; i < list->capacity; i++) {
      mpz_init(list->candidates[i].p);
      mpz_init(list->candidates[i].q);
    }
  }

  mpz_set(list->candidates[list->count].p, p);
  mpz_set(list->candidates[list->count].q, q);
  list->count++;

  pthread_mutex_unlock(mutex);
}

void clear_candidate_list(candidate_list_t *list) {
  for (int i = 0; i < list->capacity; i++) {
    mpz_clear(list->candidates[i].p);
    mpz_clear(list->candidates[i].q);
  }
  free(list->candidates);
  list->count = 0;
}

void reverse_bits(mpz_t result, mpz_t num, int bit_length) {
  mpz_set_ui(result, 0);

  for (int i = 0; i < bit_length; i++)
    if (mpz_tstbit(num, i)) mpz_setbit(result, bit_length - 1 - i);
}

void create_mask(mpz_t mask, int mask_size, int total_bits) {
  mpz_set_ui(mask, 0);

  for (int i = 0; i < mask_size; i++)
    mpz_setbit(mask, total_bits - 1 - i);

  for (int i = 0; i < mask_size; i++)
    mpz_setbit(mask, i);
}

int check_constraints(
  mpz_t p, mpz_t q, mpz_t n, mpz_t veil_xor, int mask_size
) {
  mpz_t p_masked, q_masked, product, q_rev, p_xor_q_rev;
  mpz_t mask, temp, n_masked, veil_masked;

  mpz_init(p_masked);
  mpz_init(q_masked);
  mpz_init(product);
  mpz_init(q_rev);
  mpz_init(p_xor_q_rev);
  mpz_init(mask);
  mpz_init(temp);
  mpz_init(n_masked);
  mpz_init(veil_masked);

  create_mask(mask, mask_size, SIZE);

  mpz_and(p_masked, p, mask);
  mpz_and(q_masked, q, mask);

  // Check multiplication constraint
  mpz_mul(product, p_masked, q_masked);

  // Check LSB bits of product - they must match exactly
  for (int i = 0; i < mask_size; i++)
    if (mpz_tstbit(product, i) != mpz_tstbit(n, i)) goto cleanup_false;

  // Check MSB bits of product with the special constraint
  // Extract top (mask_size + 1) bits of both product and n
  mpz_t top_n, top_product, diff;
  mpz_init(top_n);
  mpz_init(top_product);
  mpz_init(diff);

  // Get top (mask_size + 1) bits of n
  mpz_tdiv_q_2exp(top_n, n, 2 * SIZE - mask_size - 1);
  mpz_tdiv_r_2exp(top_n, top_n, mask_size + 1);

  // Get top (mask_size + 1) bits of product
  mpz_tdiv_q_2exp(top_product, product, 2 * SIZE - mask_size - 1);
  mpz_tdiv_r_2exp(top_product, top_product, mask_size + 1);

  // Check if top_product > top_n
  if (mpz_cmp(top_product, top_n) > 0) {
    mpz_clear(top_n);
    mpz_clear(top_product);
    mpz_clear(diff);
    goto cleanup_false;
  }

  // Check if top_n - top_product > mask_size + 1
  mpz_sub(diff, top_n, top_product);
  if (mpz_cmp_ui(diff, mask_size + 1) > 0) {
    mpz_clear(top_n);
    mpz_clear(top_product);
    mpz_clear(diff);
    goto cleanup_false;
  }

  mpz_clear(top_n);
  mpz_clear(top_product);
  mpz_clear(diff);

  // Check veil XOR constraint
  reverse_bits(q_rev, q_masked, SIZE);
  mpz_xor(p_xor_q_rev, p_masked, q_rev);

  // Apply mask to veil_xor for comparison
  create_mask(temp, mask_size, SIZE);
  mpz_and(veil_masked, veil_xor, temp);
  mpz_and(p_xor_q_rev, p_xor_q_rev, temp);

  if (mpz_cmp(p_xor_q_rev, veil_masked) != 0) goto cleanup_false;

  // Cleanup and return success
  mpz_clear(p_masked);
  mpz_clear(q_masked);
  mpz_clear(product);
  mpz_clear(q_rev);
  mpz_clear(p_xor_q_rev);
  mpz_clear(mask);
  mpz_clear(temp);
  mpz_clear(n_masked);
  mpz_clear(veil_masked);
  return 1;

cleanup_false:
  mpz_clear(p_masked);
  mpz_clear(q_masked);
  mpz_clear(product);
  mpz_clear(q_rev);
  mpz_clear(p_xor_q_rev);
  mpz_clear(mask);
  mpz_clear(temp);
  mpz_clear(n_masked);
  mpz_clear(veil_masked);
  return 0;
}

// Generate new candidates by extending current templates
void generate_candidates(
  candidate_list_t *new_list, mpz_t p_template, mpz_t q_template, int mask_size
) {
  mpz_t new_p, new_q;
  mpz_init(new_p);
  mpz_init(new_q);

  // Try all combinations of new MSB and LSB bits
  for (int p_msb = 0; p_msb <= 1; p_msb++) {
    for (int p_lsb = 0; p_lsb <= 1; p_lsb++) {
      for (int q_msb = 0; q_msb <= 1; q_msb++) {
        for (int q_lsb = 0; q_lsb <= 1; q_lsb++) {
          mpz_set(new_p, p_template);
          mpz_set(new_q, q_template);

          if (p_msb)
            mpz_setbit(new_p, SIZE - mask_size);
          else
            mpz_clrbit(new_p, SIZE - mask_size);

          if (q_msb)
            mpz_setbit(new_q, SIZE - mask_size);
          else
            mpz_clrbit(new_q, SIZE - mask_size);

          if (p_lsb)
            mpz_setbit(new_p, mask_size - 1);
          else
            mpz_clrbit(new_p, mask_size - 1);

          if (q_lsb)
            mpz_setbit(new_q, mask_size - 1);
          else
            mpz_clrbit(new_q, mask_size - 1);

          add_candidate(new_list, new_p, new_q);
        }
      }
    }
  }

  mpz_clear(new_p);
  mpz_clear(new_q);
}

void *thread_worker(void *arg) {
  thread_data_t *data = (thread_data_t *)arg;

  for (int i = data->start_idx;
       i < data->end_idx && i < data->input_candidates->count;
       i++) {
    candidate_list_t temp_candidates;
    init_candidate_list(&temp_candidates);

    generate_candidates(
      &temp_candidates,
      data->input_candidates->candidates[i].p,
      data->input_candidates->candidates[i].q,
      data->mask_size
    );

    for (int j = 0; j < temp_candidates.count; j++) {
      if (check_constraints(
            temp_candidates.candidates[j].p,
            temp_candidates.candidates[j].q,
            data->n,
            data->veil_xor,
            data->mask_size
          )) {
        add_candidate_safe(
          data->output_candidates,
          temp_candidates.candidates[j].p,
          temp_candidates.candidates[j].q,
          data->output_mutex
        );
      }
    }

    clear_candidate_list(&temp_candidates);
  }

  return NULL;
}

// Final verification function to check complete factorization
int verify_complete_factorization(
  candidate_list_t *candidates,
  mpz_t n,
  mpz_t veil_xor,
  mpz_t p_result,
  mpz_t q_result
) {
  printf(
    "Performing final verification of %d candidates...\n", candidates->count
  );

  mpz_t product, q_rev, p_xor_q_rev;
  mpz_init(product);
  mpz_init(q_rev);
  mpz_init(p_xor_q_rev);

  for (int i = 0; i < candidates->count; i++) {
    // Check if p * q = n
    mpz_mul(product, candidates->candidates[i].p, candidates->candidates[i].q);

    if (mpz_cmp(product, n) == 0) {
      // Also verify the veil XOR constraint
      reverse_bits(q_rev, candidates->candidates[i].q, SIZE);
      mpz_xor(p_xor_q_rev, candidates->candidates[i].p, q_rev);

      if (mpz_cmp(p_xor_q_rev, veil_xor) == 0) {
        printf("Found valid complete factorization!\n");
        printf(
          "Candidate %d satisfies both p*q=n and p^reverse(q)=veil_xor\n", i
        );

        mpz_set(p_result, candidates->candidates[i].p);
        mpz_set(q_result, candidates->candidates[i].q);

        mpz_clear(product);
        mpz_clear(q_rev);
        mpz_clear(p_xor_q_rev);
        return 1;
      }
    }
  }

  printf("No candidate satisfies the complete factorization!\n");
  mpz_clear(product);
  mpz_clear(q_rev);
  mpz_clear(p_xor_q_rev);
  return 0;
}

// Progressive factorization algorithm
int progressive_factorization(
  mpz_t p_result, mpz_t q_result, mpz_t n, mpz_t veil_xor
) {
  candidate_list_t valid_candidates, new_candidates;
  mpz_t p_template, q_template;

  init_candidate_list(&valid_candidates);
  init_candidate_list(&new_candidates);
  mpz_init(p_template);
  mpz_init(q_template);

  printf("Starting progressive factorization...\n");
  printf("n has %zu bits\n", mpz_sizeinbase(n, 2));
  printf("veil_xor has %zu bits\n", mpz_sizeinbase(veil_xor, 2));

  for (int i1 = 0; i1 <= 1; i1++) {
    for (int i2 = 0; i2 <= 1; i2++) {
      for (int i3 = 0; i3 <= 1; i3++) {
        for (int i4 = 0; i4 <= 1; i4++) {
          mpz_set_ui(p_template, 0);
          mpz_set_ui(q_template, 0);

          if (i1) mpz_setbit(p_template, SIZE - 1);
          if (i2) mpz_setbit(p_template, 0);
          if (i3) mpz_setbit(q_template, SIZE - 1);
          if (i4) mpz_setbit(q_template, 0);

          if (check_constraints(p_template, q_template, n, veil_xor, 1))
            add_candidate(&valid_candidates, p_template, q_template);
        }
      }
    }
  }

  printf("Starting with %d valid 1-bit candidates\n", valid_candidates.count);

  int num_threads = sysconf(_SC_NPROCESSORS_ONLN);
  printf("Using %d threads\n", num_threads);

  for (int mask_size = 2; mask_size <= SIZE / 2; mask_size++) {
    new_candidates.count = 0;

    if (valid_candidates.count == 0) {
      printf("No more valid candidates found\n");
      break;
    }

    // Don't create more threads than we have candidates
    int actual_threads = (valid_candidates.count < num_threads)
      ? valid_candidates.count
      : num_threads;

    // Create thread data structures
    pthread_t *threads = malloc(actual_threads * sizeof(pthread_t));
    thread_data_t *thread_data = malloc(actual_threads * sizeof(thread_data_t));
    pthread_mutex_t output_mutex = PTHREAD_MUTEX_INITIALIZER;

    // Calculate work distribution
    int candidates_per_thread = valid_candidates.count / actual_threads;
    int remaining_candidates = valid_candidates.count % actual_threads;

    // Launch threads
    int current_start = 0;
    for (int t = 0; t < actual_threads; t++) {
      thread_data[t].thread_id = t;
      thread_data[t].start_idx = current_start;

      // Add one extra candidate to first 'remaining_candidates' threads
      int candidates_for_this_thread =
        candidates_per_thread + (t < remaining_candidates ? 1 : 0);
      thread_data[t].end_idx = current_start + candidates_for_this_thread;
      current_start = thread_data[t].end_idx;

      thread_data[t].input_candidates = &valid_candidates;
      thread_data[t].output_candidates = &new_candidates;
      thread_data[t].mask_size = mask_size;
      thread_data[t].output_mutex = &output_mutex;

      // Debug output for work distribution
      if (mask_size <= 4) { // Only print for early iterations
        printf(
          "Thread %d: processing candidates %d to %d\n",
          t,
          thread_data[t].start_idx,
          thread_data[t].end_idx - 1
        );
      }

      // Initialize mpz_t values for this thread
      mpz_init_set(thread_data[t].n, n);
      mpz_init_set(thread_data[t].veil_xor, veil_xor);

      pthread_create(&threads[t], NULL, thread_worker, &thread_data[t]);
    }

    // Wait for all threads to complete
    for (int t = 0; t < actual_threads; t++) {
      pthread_join(threads[t], NULL);
      mpz_clear(thread_data[t].n);
      mpz_clear(thread_data[t].veil_xor);
    }

    // Cleanup thread resources
    pthread_mutex_destroy(&output_mutex);
    free(threads);
    free(thread_data);

    candidate_list_t temp = valid_candidates;
    valid_candidates = new_candidates;
    new_candidates = temp;
    new_candidates.count = 0;

    printf(
      "After %d bits: %d valid candidates\n", mask_size, valid_candidates.count
    );

    if (valid_candidates.count == 0) {
      printf("No more valid candidates found\n");
      break;
    }

    // Limit candidates to prevent memory explosion
    if (valid_candidates.count > 500000) {
      printf("Too many candidates, keeping first 500000\n");
      valid_candidates.count = 500000;
    }

    // When we reach 512 bits (half of SIZE), perform final verification
    if (mask_size == SIZE / 2) {
      printf(
        "Reached %d bits - performing final complete factorization check...\n",
        mask_size
      );
      if (verify_complete_factorization(
            &valid_candidates, n, veil_xor, p_result, q_result
          )) {
        // Found a valid complete factorization
        clear_candidate_list(&valid_candidates);
        clear_candidate_list(&new_candidates);
        mpz_clear(p_template);
        mpz_clear(q_template);
        return 1;
      } else {
        printf("No complete factorization found at 512 bits. Continuing...\n");
      }
    }
  }

  if (valid_candidates.count == 0) {
    printf("No valid factors found\n");
    clear_candidate_list(&valid_candidates);
    clear_candidate_list(&new_candidates);
    mpz_clear(p_template);
    mpz_clear(q_template);
    return 0;
  }

  // Return the first valid solution
  mpz_set(p_result, valid_candidates.candidates[0].p);
  mpz_set(q_result, valid_candidates.candidates[0].q);

  clear_candidate_list(&valid_candidates);
  clear_candidate_list(&new_candidates);
  mpz_clear(p_template);
  mpz_clear(q_template);
  return 1;
}

// Parse challenge data from out.txt
int parse_challenge_data(mpz_t n, mpz_t c, mpz_t veil_xor) {
  FILE *file = fopen("out.txt", "r");
  if (!file) {
    printf("Error: Could not open out.txt\n");
    return 0;
  }

  char line[4096];
  int n_found = 0, c_found = 0, veil_found = 0;

  while (fgets(line, sizeof(line), file)) {
    if (strncmp(line, "n :", 3) == 0) {
      if (mpz_set_str(n, line + 4, 10) == 0) n_found = 1;
    } else if (strncmp(line, "c :", 3) == 0) {
      if (mpz_set_str(c, line + 4, 10) == 0) c_found = 1;
    } else if (strncmp(line, "Veil XOR:", 9) == 0) {
      if (mpz_set_str(veil_xor, line + 10, 10) == 0) veil_found = 1;
    }
  }

  fclose(file);
  return n_found && c_found && veil_found;
}

// Solve the RSA challenge
int solve_challenge(mpz_t n, mpz_t c, mpz_t veil_xor) {
  mpz_t p, q, phi, d, e, m;
  mpz_init(p);
  mpz_init(q);
  mpz_init(phi);
  mpz_init(d);
  mpz_init(e);
  mpz_init(m);

  printf("Starting progressive factorization...\n");

  if (!progressive_factorization(p, q, n, veil_xor)) {
    printf("Failed to find factors\n");
    goto cleanup;
  }

  printf("Found factors:\n");
  gmp_printf("p = %Zd\n", p);
  gmp_printf("q = %Zd\n", q);

  // Verify the solution
  mpz_t temp, q_rev;
  mpz_init(temp);
  mpz_init(q_rev);

  mpz_mul(temp, p, q);
  if (mpz_cmp(temp, n) != 0) {
    printf("Invalid factorization: p * q != n\n");
    mpz_clear(temp);
    mpz_clear(q_rev);
    goto cleanup;
  }

  reverse_bits(q_rev, q, SIZE);
  mpz_xor(temp, p, q_rev);
  if (mpz_cmp(temp, veil_xor) != 0) {
    printf("Invalid veil XOR: p ^ reverse(q) != veil_xor\n");
    mpz_clear(temp);
    mpz_clear(q_rev);
    goto cleanup;
  }

  printf("Factorization verified!\n");

  // Compute private exponent
  mpz_sub_ui(temp, p, 1);
  mpz_sub_ui(q_rev, q, 1);
  mpz_mul(phi, temp, q_rev);

  mpz_set_ui(e, 65537);
  if (mpz_invert(d, e, phi) == 0) {
    printf("Failed to compute private exponent\n");
    mpz_clear(temp);
    mpz_clear(q_rev);
    goto cleanup;
  }

  // Decrypt
  mpz_powm(m, c, d, n);

  printf("Decrypted message (hex): ");
  mpz_out_str(stdout, 16, m);
  printf("\n");

  // Try to convert to ASCII if reasonable size
  if (mpz_sizeinbase(m, 2) <= 1024) {
    char *flag_str = mpz_get_str(NULL, 16, m);
    printf("Flag (hex): %s\n", flag_str);

    // Convert hex to ASCII
    int len = strlen(flag_str);
    if (len % 2 == 0) {
      printf("Flag (ASCII): ");
      for (int i = 0; i < len; i += 2) {
        char hex_byte[3] = {flag_str[i], flag_str[i + 1], '\0'};
        int byte_val = strtol(hex_byte, NULL, 16);
        if (byte_val >= 32 && byte_val <= 126)
          printf("%c", byte_val);
        else
          printf("\\x%02x", byte_val);
      }
      printf("\n");
    }
    free(flag_str);
  }

  mpz_clear(temp);
  mpz_clear(q_rev);
  mpz_clear(p);
  mpz_clear(q);
  mpz_clear(phi);
  mpz_clear(d);
  mpz_clear(e);
  mpz_clear(m);
  return 1;

cleanup:
  mpz_clear(p);
  mpz_clear(q);
  mpz_clear(phi);
  mpz_clear(d);
  mpz_clear(e);
  mpz_clear(m);
  return 0;
}

int main() {
  mpz_t n, c, veil_xor;
  mpz_init(n);
  mpz_init(c);
  mpz_init(veil_xor);

  printf("=== Veil XOR RSA Solver ===\n\n");

  if (!parse_challenge_data(n, c, veil_xor)) {
    printf("Failed to parse challenge data from out.txt\n");
    goto cleanup;
  }

  printf("Challenge data loaded successfully\n");
  gmp_printf("n = %Zd\n", n);
  gmp_printf("c = %Zd\n", c);
  gmp_printf("veil_xor = %Zd\n", veil_xor);
  printf("\n");

  if (!solve_challenge(n, c, veil_xor)) {
    printf("Failed to solve challenge\n");
    goto cleanup;
  }

  mpz_clear(n);
  mpz_clear(c);
  mpz_clear(veil_xor);
  return 0;

cleanup:
  mpz_clear(n);
  mpz_clear(c);
  mpz_clear(veil_xor);
  return 1;
}