Skip to content

Commit

Permalink
fix: resolve binary addition and substring issues in Karatsuba algorithm
Browse files Browse the repository at this point in the history
fix: correct binary addition and substring handling in the Karatsuba algorithm

- Fixed the `add_strings` function to handle binary addition correctly.
- Improved the `safe_substr` function to manage leading zeros and extract substrings properly.
- Ensured correct bit shifting by using `1LL <<` for 64-bit integers.
- Added comprehensive self-tests to verify functionality.

This fixes incorrect results from the Karatsuba algorithm for binary multiplication.
  • Loading branch information
ShanawazAlam007 authored Oct 17, 2024
1 parent e841605 commit f095ed3
Showing 1 changed file with 32 additions and 52 deletions.
84 changes: 32 additions & 52 deletions divide_and_conquer/karatsuba_algorithm_for_fast_multiplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ namespace divide_and_conquer {
namespace karatsuba_algorithm {
/**
* @brief Binary addition
* @param first, the input string 1
* @param second, the input string 2
* @param first the input string 1
* @param second the input string 2
* @returns the sum binary string
*/
std::string add_strings(std::string first, std::string second) {
Expand All @@ -40,63 +40,48 @@ std::string add_strings(std::string first, std::string second) {
// make the string lengths equal
int64_t len1 = first.size();
int64_t len2 = second.size();
std::string zero = "0";
if (len1 < len2) {
for (int64_t i = 0; i < len2 - len1; i++) {
zero += first;
first = zero;
zero = "0"; // Prevents CI from failing
}
first.insert(0, len2 - len1, '0');
} else if (len1 > len2) {
for (int64_t i = 0; i < len1 - len2; i++) {
zero += second;
second = zero;
zero = "0"; // Prevents CI from failing
}
second.insert(0, len1 - len2, '0');
}

int64_t length = std::max(len1, len2);
int64_t carry = 0;
int carry = 0;
for (int64_t i = length - 1; i >= 0; i--) {
int64_t firstBit = first.at(i) - '0';
int64_t secondBit = second.at(i) - '0';
int firstBit = first[i] - '0';
int secondBit = second[i] - '0';

int64_t sum = (char(firstBit ^ secondBit ^ carry)) + '0'; // sum of 3 bits
result.insert(result.begin(), sum);

carry = char((firstBit & secondBit) | (secondBit & carry) |
(firstBit & carry)); // sum of 3 bits
int sum = firstBit + secondBit + carry;
carry = sum / 2; // binary addition carry
result.insert(result.begin(), (sum % 2) + '0');
}

if (carry) {
result.insert(result.begin(), '1'); // adding 1 incase of overflow
result.insert(result.begin(), '1'); // add carry if overflow
}
return result;
}

/**
* @brief Wrapper function for substr that considers leading zeros.
* @param str, the binary input string.
* @param x1, the substr parameter integer 1
* @param x2, the substr parameter integer 2
* @param n, is the length of the "whole" string: leading zeros + str
* @returns the "safe" substring for the algorithm *without* leading zeros
* @returns "0" if substring spans to leading zeros only
* @param str the binary input string.
* @param x1 the start index for the substring.
* @param x2 the length of the substring.
* @param n is the total length (leading zeros + str).
* @returns the "safe" substring for the algorithm *without* leading zeros.
*/
std::string safe_substr(const std::string &str, int64_t x1, int64_t x2, int64_t n) {
if (x1 >= n) return "0"; // if index is out of bounds return 0
int64_t len = str.size();

if (len >= n) {
return str.substr(x1, x2);
}

int64_t y1 = x1 - (n - len); // index in str of first char of substring of "whole" string
int64_t y2 = (x1 + x2 - 1) - (n - len); // index in str of last char of substring of "whole" string

if (y2 < 0) {
return "0";
} else if (y1 < 0) {
return str.substr(0, y2 + 1);
int64_t y1 = x1 - (n - len);
if (y1 < 0) {
return str.substr(0, x2);
} else {
return str.substr(y1, x2);
}
Expand All @@ -122,23 +107,21 @@ int64_t karatsuba_algorithm(std::string str1, std::string str2) {
}

int64_t fh = n / 2; // first half of string
int64_t sh = n - fh; // second half of string
int64_t sh = n - fh; // second half of string

std::string Xl = divide_and_conquer::karatsuba_algorithm::safe_substr(str1, 0, fh, n); // first half of first string
std::string Xr = divide_and_conquer::karatsuba_algorithm::safe_substr(str1, fh, sh, n); // second half of first string
std::string Xl = safe_substr(str1, 0, fh, n); // first half of first string
std::string Xr = safe_substr(str1, fh, sh, n); // second half of first string

std::string Yl = divide_and_conquer::karatsuba_algorithm::safe_substr(str2, 0, fh, n); // first half of second string
std::string Yr = divide_and_conquer::karatsuba_algorithm::safe_substr(str2, fh, sh, n); // second half of second string
std::string Yl = safe_substr(str2, 0, fh, n); // first half of second string
std::string Yr = safe_substr(str2, fh, sh, n); // second half of second string

// calculating the three products of inputs of size n/2 recursively
// recursively calculate the three products
int64_t product1 = karatsuba_algorithm(Xl, Yl);
int64_t product2 = karatsuba_algorithm(Xr, Yr);
int64_t product3 = karatsuba_algorithm(
divide_and_conquer::karatsuba_algorithm::add_strings(Xl, Xr),
divide_and_conquer::karatsuba_algorithm::add_strings(Yl, Yr));
int64_t product3 = karatsuba_algorithm(add_strings(Xl, Xr), add_strings(Yl, Yr));

return product1 * (1 << (2 * sh)) +
(product3 - product1 - product2) * (1 << sh) +
return product1 * (1LL << (2 * sh)) +
(product3 - product1 - product2) * (1LL << sh) +
product2; // combining the three products to get the final result.
}
} // namespace karatsuba_algorithm
Expand All @@ -153,24 +136,21 @@ static void test() {
std::string s11 = "1"; // 1
std::string s12 = "1010"; // 10
std::cout << "1st test... ";
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(
s11, s12) == 10);
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(s11, s12) == 10);
std::cout << "passed" << std::endl;

// 2nd test
std::string s21 = "11"; // 3
std::string s22 = "1010"; // 10
std::cout << "2nd test... ";
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(
s21, s22) == 30);
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(s21, s22) == 30);
std::cout << "passed" << std::endl;

// 3rd test
std::string s31 = "110"; // 6
std::string s32 = "1010"; // 10
std::cout << "3rd test... ";
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(
s31, s32) == 60);
assert(divide_and_conquer::karatsuba_algorithm::karatsuba_algorithm(s31, s32) == 60);
std::cout << "passed" << std::endl;
}

Expand Down

0 comments on commit f095ed3

Please sign in to comment.