Skip to content

Commit

Permalink
fix config (zilliztech#903)
Browse files Browse the repository at this point in the history
Signed-off-by: foxspy <[email protected]>
  • Loading branch information
foxspy authored Oct 18, 2024
1 parent 2a13a9d commit 616f5e3
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 36 deletions.
87 changes: 82 additions & 5 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ typedef nlohmann::json Json;
#define CFG_INT std::optional<int32_t>
#endif

#ifndef CFG_INT64
#define CFG_INT64 std::optional<int64_t>
#endif

#ifndef CFG_STRING
#define CFG_STRING std::optional<std::string>
#endif
Expand Down Expand Up @@ -140,6 +144,31 @@ struct Entry<CFG_INT> {
bool allow_empty_without_default = false;
};

template <>
struct Entry<CFG_INT64> {
explicit Entry(CFG_INT64* v) {
val = v;
default_val = std::nullopt;
type = 0x0;
range = std::nullopt;
desc = std::nullopt;
}
Entry() {
val = nullptr;
default_val = std::nullopt;
type = 0x0;
range = std::nullopt;
desc = std::nullopt;
}

CFG_INT64* val;
std::optional<CFG_INT64::value_type> default_val;
uint32_t type;
std::optional<std::pair<CFG_INT64::value_type, CFG_INT64::value_type>> range;
std::optional<std::string> desc;
bool allow_empty_without_default = false;
};

template <>
struct Entry<CFG_BOOL> {
explicit Entry(CFG_BOOL* v) {
Expand Down Expand Up @@ -317,12 +346,12 @@ class Config {
}
if (!json[it.first].is_number_integer()) {
std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) +
") should be integer";
") should be integer(64bit)";
show_err_msg(msg);
return Status::type_conflict_in_json;
}
if (ptr->range.has_value()) {
if (json[it.first].get<long>() > std::numeric_limits<CFG_INT::value_type>::max()) {
if (json[it.first].get<int64_t>() > std::numeric_limits<CFG_INT::value_type>::max()) {
std::string msg = "Arithmetic overflow: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should not bigger than " +
std::to_string(std::numeric_limits<CFG_INT::value_type>::max());
Expand All @@ -346,6 +375,54 @@ class Config {
}
}

if (const Entry<CFG_INT64>* ptr = std::get_if<Entry<CFG_INT64>>(&var)) {
if (!(type & ptr->type)) {
continue;
}
if (json.find(it.first) == json.end()) {
if (!ptr->default_val.has_value()) {
if (ptr->allow_empty_without_default) {
continue;
}
std::string msg = "param '" + it.first + "' not exist in json";
show_err_msg(msg);
return Status::invalid_param_in_json;
} else {
*ptr->val = ptr->default_val;
continue;
}
}
if (!json[it.first].is_number_integer()) {
std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) +
") should be unsigned integer";
show_err_msg(msg);
return Status::type_conflict_in_json;
}
if (ptr->range.has_value()) {
if (json[it.first].get<int64_t>() > std::numeric_limits<CFG_INT64::value_type>::max()) {
std::string msg = "Arithmetic overflow: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should not bigger than " +
std::to_string(std::numeric_limits<CFG_INT64::value_type>::max());
show_err_msg(msg);
return Status::arithmetic_overflow;
}
CFG_INT64::value_type v = json[it.first];
auto range_val = ptr->range.value();
if (range_val.first <= v && v <= range_val.second) {
*ptr->val = v;
} else {
std::string msg = "Out of range in json: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should be in range [" +
std::to_string(range_val.first) + ", " + std::to_string(range_val.second) +
"]";
show_err_msg(msg);
return Status::out_of_range_in_json;
}
} else {
*ptr->val = json[it.first];
}
}

if (const Entry<CFG_FLOAT>* ptr = std::get_if<Entry<CFG_FLOAT>>(&var)) {
if (!(type & ptr->type)) {
continue;
Expand Down Expand Up @@ -478,8 +555,8 @@ class Config {
virtual ~Config() {
}

using VarEntry = std::variant<Entry<CFG_STRING>, Entry<CFG_FLOAT>, Entry<CFG_INT>, Entry<CFG_BOOL>,
Entry<CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE>>;
using VarEntry = std::variant<Entry<CFG_STRING>, Entry<CFG_FLOAT>, Entry<CFG_INT>, Entry<CFG_INT64>,
Entry<CFG_BOOL>, Entry<CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE>>;
std::unordered_map<std::string, VarEntry> __DICT__;

protected:
Expand All @@ -501,7 +578,7 @@ const float defaultRangeFilter = 1.0f / 0.0;

class BaseConfig : public Config {
public:
CFG_INT dim; // just used for config verify
CFG_INT64 dim; // just used for config verify
CFG_STRING metric_type;
CFG_INT k;
CFG_INT num_build_thread;
Expand Down
9 changes: 7 additions & 2 deletions src/common/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,17 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg
}
if (v < std::numeric_limits<CFG_INT::value_type>::min() ||
v > std::numeric_limits<CFG_INT::value_type>::max()) {
*err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
if (err_msg) {
*err_msg =
"integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
}
return knowhere::Status::invalid_value_in_json;
}
json[key_str] = static_cast<CFG_INT::value_type>(v);
} catch (const std::out_of_range&) {
*err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
if (err_msg) {
*err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
}
return knowhere::Status::invalid_value_in_json;
} catch (const std::invalid_argument&) {
KNOWHERE_THROW_MSG("invalid integer value, key: '" + key_str + "', value: '" + value_str + "'");
Expand Down
6 changes: 4 additions & 2 deletions src/index/diskann/diskann_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,10 @@ class DiskANNConfig : public BaseConfig {
if (!search_list_size.has_value()) {
search_list_size = std::max(k.value(), kSearchListSizeMinValue);
} else if (k.value() > search_list_size.value()) {
*err_msg = "search_list_size(" + std::to_string(search_list_size.value()) +
") should be larger than k(" + std::to_string(k.value()) + ")";
if (err_msg) {
*err_msg = "search_list_size(" + std::to_string(search_list_size.value()) +
") should be larger than k(" + std::to_string(k.value()) + ")";
}
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::out_of_range_in_json;
}
Expand Down
4 changes: 3 additions & 1 deletion src/index/gpu_raft/gpu_raft_brute_force_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ struct GpuRaftBruteForceConfig : public BaseConfig {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
}
return Status::invalid_metric_type;
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/index/gpu_raft/gpu_raft_cagra_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,10 @@ struct GpuRaftCagraConfig : public BaseConfig {

if (search_width.has_value()) {
if (std::max(itopk_size.value(), kAlignFactor * search_width.value()) < k.value()) {
*err_msg = "max((itopk_size + 31)// 32, search_width) * 32< topk";
LOG_KNOWHERE_ERROR_ << *err_msg;
if (err_msg) {
*err_msg = "max((itopk_size + 31)// 32, search_width) * 32< topk";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::out_of_range_in_json;
}
} else {
Expand Down
4 changes: 3 additions & 1 deletion src/index/gpu_raft/gpu_raft_ivf_flat_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ struct GpuRaftIvfFlatConfig : public IvfFlatConfig {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
}
return Status::invalid_metric_type;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/index/gpu_raft/gpu_raft_ivf_pq_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ struct GpuRaftIvfPqConfig : public IvfPqConfig {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
}
return Status::invalid_metric_type;
}
}
Expand Down
26 changes: 17 additions & 9 deletions src/index/hnsw/faiss_hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ class FaissHnswConfig : public BaseConfig {
if (!ef.has_value()) {
ef = std::max(k.value(), kEfMinValue);
} else if (k.value() > ef.value()) {
*err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" +
std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
if (err_msg) {
*err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" +
std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::out_of_range_in_json;
}
break;
Expand Down Expand Up @@ -140,8 +142,10 @@ class FaissHnswFlatConfig : public FaissHnswConfig {
if (param_type == PARAM_TYPE::TRAIN) {
// prohibit refine
if (refine.value_or(false) || refine_type.has_value() || refine_k.has_value()) {
*err_msg = "refine is not supported for this index";
LOG_KNOWHERE_ERROR_ << *err_msg;
if (err_msg) {
*err_msg = "refine is not supported for this index";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::invalid_value_in_json;
}
}
Expand Down Expand Up @@ -174,16 +178,20 @@ class FaissHnswSqConfig : public FaissHnswConfig {
if (param_type == PARAM_TYPE::TRAIN) {
auto sq_type_v = sq_type.value();
if (!WhetherAcceptableQuantType(sq_type_v)) {
*err_msg = "invalid scalar quantizer type";
LOG_KNOWHERE_ERROR_ << *err_msg;
if (err_msg) {
*err_msg = "invalid scalar quantizer type";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::invalid_value_in_json;
}

// check refine
if (refine_type.has_value()) {
if (!WhetherAcceptableRefineType(refine_type.value())) {
*err_msg = "invalid refine type type";
LOG_KNOWHERE_ERROR_ << *err_msg;
if (err_msg) {
*err_msg = "invalid refine type type";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::invalid_value_in_json;
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/index/hnsw/hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ class HnswConfig : public BaseConfig {
if (!ef.has_value()) {
ef = std::max(k.value(), kEfMinValue);
} else if (k.value() > ef.value()) {
*err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" +
std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
if (err_msg) {
*err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" +
std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::out_of_range_in_json;
}
break;
Expand Down
6 changes: 4 additions & 2 deletions src/index/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,17 @@ template <typename T>
inline Status
Index<T>::Train(const DataSetPtr dataset, const Json& json) {
auto cfg = this->node->CreateConfig();
RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Train"));
std::string msg;
RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Train", &msg));
return this->node->Train(dataset, std::move(cfg));
}

template <typename T>
inline Status
Index<T>::Add(const DataSetPtr dataset, const Json& json) {
auto cfg = this->node->CreateConfig();
RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Add"));
std::string msg;
RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Add", &msg));
return this->node->Add(dataset, std::move(cfg));
}

Expand Down
25 changes: 17 additions & 8 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ class IvfPqConfig : public IvfConfig {
int vec_dim = dim.value();
int param_m = m.value();
if (vec_dim % param_m != 0) {
*err_msg =
"dimension must be able to be divided by `m`, dimension: " + std::to_string(vec_dim) +
", m: " + std::to_string(param_m);
if (err_msg) {
*err_msg =
"dimension must be able to be divided by `m`, dimension: " + std::to_string(vec_dim) +
", m: " + std::to_string(param_m);
}
return Status::invalid_args;
}
}
Expand Down Expand Up @@ -115,7 +117,10 @@ class ScannConfig : public IvfFlatConfig {
if (dim.has_value()) {
int vec_dim = dim.value();
if (vec_dim % 2 != 0) {
*err_msg = "dimension must be able to be divided by 2, dimension:" + std::to_string(vec_dim);
if (err_msg) {
*err_msg =
"dimension must be able to be divided by 2, dimension:" + std::to_string(vec_dim);
}
return Status::invalid_args;
}
}
Expand Down Expand Up @@ -161,7 +166,9 @@ class IvfBinConfig : public IvfConfig {
constexpr std::array<std::string_view, 2> legal_metric_list{"HAMMING", "JACCARD"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [HAMMING JACCARD]";
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [HAMMING JACCARD]";
}
return Status::invalid_metric_type;
}
}
Expand Down Expand Up @@ -195,9 +202,11 @@ class IvfSqCcConfig : public IvfFlatCcConfig {
auto legal_code_size_list = std::vector<int>{4, 6, 8, 16};
if (std::find(legal_code_size_list.begin(), legal_code_size_list.end(), code_size_v) ==
legal_code_size_list.end()) {
*err_msg =
"compress a vector into (code_size * dim)/8 bytes, code size value should be in 4, 6, 8 and 16";
LOG_KNOWHERE_ERROR_ << *err_msg;
if (err_msg) {
*err_msg =
"compress a vector into (code_size * dim)/8 bytes, code size value should be in 4, 6, 8 and 16";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::invalid_value_in_json;
}
}
Expand Down
14 changes: 14 additions & 0 deletions tests/ut/test_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ TEST_CASE("Test config json parse", "[config]") {
CHECK(s == knowhere::Status::success);
}

SECTION("check int64 json values") {
auto unsigned_int_json_str = GENERATE(as<std::string>{},
R"({
"dim": 10000000000
})");
knowhere::BaseConfig test_config;
knowhere::Json test_json = knowhere::Json::parse(unsigned_int_json_str);
s = knowhere::Config::FormatAndCheck(test_config, test_json);
CHECK(s == knowhere::Status::success);
s = knowhere::Config::Load(test_config, test_json, knowhere::TRAIN);
CHECK(s == knowhere::Status::success);
CHECK(test_config.dim.value() == 10000000000L);
}

SECTION("check invalid json values") {
auto invalid_json_str = GENERATE(as<std::string>{},
R"({
Expand Down

0 comments on commit 616f5e3

Please sign in to comment.