From f92fc368ba21e3638cae92cc7715f9ce752e5b83 Mon Sep 17 00:00:00 2001 From: lhpqaq Date: Mon, 12 Jan 2026 14:44:23 +0800 Subject: [PATCH] Support Mixed-Precision Quantization --- examples/common-ggml.cpp | 281 ++++++++++++++++++++++++++++++++- examples/common-ggml.h | 19 +++ examples/quantize/README.md | 40 +++++ examples/quantize/quantize.cpp | 93 +++++++++-- src/whisper.cpp | 53 +++++-- 5 files changed, 464 insertions(+), 22 deletions(-) diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index c42b644fedd..d7a9a8d70c1 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -2,6 +2,8 @@ #include #include +#include +#include static const std::map GGML_FTYPE_MAP = { {"q4_0", GGML_FTYPE_MOSTLY_Q4_0}, @@ -16,6 +18,21 @@ static const std::map GGML_FTYPE_MAP = { {"q6_k", GGML_FTYPE_MOSTLY_Q6_K}, }; +static const std::map GGML_TYPE_MAP = { + {"q4_0", GGML_TYPE_Q4_0}, + {"q4_1", GGML_TYPE_Q4_1}, + {"q5_0", GGML_TYPE_Q5_0}, + {"q5_1", GGML_TYPE_Q5_1}, + {"q8_0", GGML_TYPE_Q8_0}, + {"q2_k", GGML_TYPE_Q2_K}, + {"q3_k", GGML_TYPE_Q3_K}, + {"q4_k", GGML_TYPE_Q4_K}, + {"q5_k", GGML_TYPE_Q5_K}, + {"q6_k", GGML_TYPE_Q6_K}, + {"f16", GGML_TYPE_F16}, + {"f32", GGML_TYPE_F32}, +}; + void ggml_print_ftypes(FILE * fp) { for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) { fprintf(fp, " type = \"%s\" or %d\n", it->first.c_str(), it->second); @@ -38,6 +55,18 @@ enum ggml_ftype ggml_parse_ftype(const char * str) { return ftype; } +ggml_type ggml_parse_qtype(const char * str) { + std::string str_lower(str); + std::transform(str_lower.begin(), str_lower.end(), str_lower.begin(), ::tolower); + + const auto it = GGML_TYPE_MAP.find(str_lower); + if (it == GGML_TYPE_MAP.end()) { + fprintf(stderr, "%s: unknown qtype '%s'\n", __func__, str); + return GGML_TYPE_COUNT; + } + return it->second; +} + bool ggml_common_quantize_0( std::ifstream & finp, std::ofstream & fout, @@ -159,10 +188,13 @@ bool ggml_common_quantize_0( ttype = qtype; } else { - const int bpe = (ttype == 0) ? sizeof(float) : sizeof(uint16_t); + // For non-quantized tensors, we need to correctly calculate size based on type + // Use ggml_row_size to get the correct size for the tensor's row + const size_t row_size = ggml_row_size((ggml_type) ttype, ne[0]); + const size_t data_size = row_size * (nelements / ne[0]); - data_u8.resize(nelements*bpe); - finp.read(reinterpret_cast(data_u8.data()), nelements * bpe); + data_u8.resize(data_size); + finp.read(reinterpret_cast(data_u8.data()), data_size); } fout.write(reinterpret_cast(&n_dims), sizeof(n_dims)); @@ -238,3 +270,246 @@ bool ggml_common_quantize_0( return true; } + +// Extended quantization function with per-tensor quantization support +bool ggml_common_quantize_0( + std::ifstream & finp, + std::ofstream & fout, + const ggml_ftype ftype, + const std::vector & to_quant, + const std::vector & to_skip, + const std::vector & tensor_quant_specs) { + + ggml_type default_qtype = GGML_TYPE_F32; + + switch (ftype) { + case GGML_FTYPE_MOSTLY_Q4_0: default_qtype = GGML_TYPE_Q4_0; break; + case GGML_FTYPE_MOSTLY_Q4_1: default_qtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q5_0: default_qtype = GGML_TYPE_Q5_0; break; + case GGML_FTYPE_MOSTLY_Q5_1: default_qtype = GGML_TYPE_Q5_1; break; + case GGML_FTYPE_MOSTLY_Q8_0: default_qtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q2_K: default_qtype = GGML_TYPE_Q2_K; break; + case GGML_FTYPE_MOSTLY_Q3_K: default_qtype = GGML_TYPE_Q3_K; break; + case GGML_FTYPE_MOSTLY_Q4_K: default_qtype = GGML_TYPE_Q4_K; break; + case GGML_FTYPE_MOSTLY_Q5_K: default_qtype = GGML_TYPE_Q5_K; break; + case GGML_FTYPE_MOSTLY_Q6_K: default_qtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_UNKNOWN: + case GGML_FTYPE_ALL_F32: + case GGML_FTYPE_MOSTLY_F16: + case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: + case GGML_FTYPE_MOSTLY_IQ2_XXS: + case GGML_FTYPE_MOSTLY_IQ2_XS: + case GGML_FTYPE_MOSTLY_IQ2_S: + case GGML_FTYPE_MOSTLY_IQ3_XXS: + case GGML_FTYPE_MOSTLY_IQ3_S: + case GGML_FTYPE_MOSTLY_IQ1_S: + case GGML_FTYPE_MOSTLY_IQ4_NL: + case GGML_FTYPE_MOSTLY_IQ4_XS: + case GGML_FTYPE_MOSTLY_IQ1_M: + case GGML_FTYPE_MOSTLY_BF16: + case GGML_FTYPE_MOSTLY_MXFP4: + { + fprintf(stderr, "%s: unsupported model type %d (ftype=%d)\n", __func__, ftype, ftype); + return false; + } + }; + + if (!ggml_is_quantized(default_qtype)) { + fprintf(stderr, "%s: invalid quantization type %d (%s)\n", __func__, default_qtype, ggml_type_name(default_qtype)); + return false; + } + + // Pre-compile regex patterns for efficiency + struct compiled_pattern { + std::regex regex; + ggml_type quant_type; + }; + std::vector compiled_patterns; + compiled_patterns.reserve(tensor_quant_specs.size()); + + for (const auto & spec : tensor_quant_specs) { + try { + compiled_patterns.push_back({std::regex(spec.pattern), spec.quant_type}); + } catch (const std::regex_error & e) { + fprintf(stderr, "%s: invalid regex pattern '%s': %s\n", __func__, spec.pattern.c_str(), e.what()); + return false; + } + } + + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector work; + + std::vector data_u8; + std::vector data_f16; + std::vector data_f32; + + std::unordered_map quant_type_counts; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + finp.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + finp.read(reinterpret_cast(&length), sizeof(length)); + finp.read(reinterpret_cast(&ttype), sizeof(ttype)); + + if (finp.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + finp.read (reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + finp.read (&name[0], length); + + printf("%64s - [%5d, %5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype)); + + bool quantize = false; + ggml_type qtype = default_qtype; + + // check if we should quantize this tensor + for (const auto & s : to_quant) { + if (std::regex_match(name, std::regex(s))) { + quantize = true; + break; + } + } + + // check if we should skip this tensor + for (const auto & s : to_skip) { + if (std::regex_match(name, std::regex(s))) { + quantize = false; + break; + } + } + + // check for per-tensor quantization specification + if (quantize) { + for (const auto & cp : compiled_patterns) { + if (std::regex_match(name, cp.regex)) { + qtype = cp.quant_type; + printf("matched pattern -> %s ", ggml_type_name(qtype)); + break; + } + } + } + + // quantize only 2D tensors + quantize &= (n_dims == 2); + + if (quantize) { + if (ttype != GGML_TYPE_F32 && ttype != GGML_TYPE_F16) { + fprintf(stderr, "%s: unsupported ttype %d (%s) for integer quantization\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); + return false; + } + + if (ttype == GGML_TYPE_F16) { + data_f16.resize(nelements); + finp.read(reinterpret_cast(data_f16.data()), nelements * sizeof(ggml_fp16_t)); + data_f32.resize(nelements); + for (int i = 0; i < nelements; ++i) { + data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); + } + } else { + data_f32.resize(nelements); + finp.read(reinterpret_cast(data_f32.data()), nelements * sizeof(float)); + } + + ttype = qtype; + quant_type_counts[ggml_type_name(qtype)]++; + } else { + // For non-quantized tensors, we need to correctly calculate size based on type + // Use ggml_row_size to get the correct size for the tensor's row + const size_t row_size = ggml_row_size((ggml_type) ttype, ne[0]); + const size_t data_size = row_size * (nelements / ne[0]); + + data_u8.resize(data_size); + finp.read(reinterpret_cast(data_u8.data()), data_size); + } + + fout.write(reinterpret_cast(&n_dims), sizeof(n_dims)); + fout.write(reinterpret_cast(&length), sizeof(length)); + fout.write(reinterpret_cast(&ttype), sizeof(ttype)); + for (int i = 0; i < n_dims; ++i) { + fout.write(reinterpret_cast(&ne[i]), sizeof(ne[i])); + } + fout.write(&name[0], length); + + if (quantize) { + work.resize(nelements); // for quantization + + size_t cur_size = 0; + switch ((ggml_type) ttype) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + { + cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements/ne[0], ne[0], nullptr); + } break; + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_I64: + case GGML_TYPE_F64: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_BF16: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_COUNT: + { + fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); + return false; + } + } + + fout.write(reinterpret_cast(work.data()), cur_size); + total_size_new += cur_size; + + printf("size = %8.2f MB -> %8.2f MB\n", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); + } else { + printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0); + fout.write(reinterpret_cast(data_u8.data()), data_u8.size()); + total_size_new += data_u8.size(); + } + + total_size_org += nelements * sizeof(float); + } + + printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(default_qtype)); + + printf("%s: quantization type summary:\n", __func__); + for (const auto & kv : quant_type_counts) { + printf("%s: %s: %d tensors\n", __func__, kv.first.c_str(), kv.second); + } + + return true; +} diff --git a/examples/common-ggml.h b/examples/common-ggml.h index 477de341a1f..eba243ae2a2 100644 --- a/examples/common-ggml.h +++ b/examples/common-ggml.h @@ -5,9 +5,19 @@ #include #include #include +#include + +// Structure for per-tensor quantization specification +struct tensor_quant_spec { + std::string pattern; // regex pattern to match tensor names + ggml_type quant_type; // quantization type for matched tensors +}; enum ggml_ftype ggml_parse_ftype(const char * str); +// Parse a quantization type string (e.g., "q4_0", "q8_0") +ggml_type ggml_parse_qtype(const char * str); + void ggml_print_ftypes(FILE * fp = stderr); bool ggml_common_quantize_0( @@ -16,3 +26,12 @@ bool ggml_common_quantize_0( const ggml_ftype ftype, const std::vector & to_quant, const std::vector & to_skip); + +// Extended quantization function with per-tensor quantization support +bool ggml_common_quantize_0( + std::ifstream & finp, + std::ofstream & fout, + const ggml_ftype ftype, + const std::vector & to_quant, + const std::vector & to_skip, + const std::vector & tensor_quant_specs); diff --git a/examples/quantize/README.md b/examples/quantize/README.md index 2d918acf7f3..5ad7a12858d 100644 --- a/examples/quantize/README.md +++ b/examples/quantize/README.md @@ -1,3 +1,43 @@ # quantize Tool for integer quantization of Whisper `ggml` model files + +## Features + +- Standard uniform quantization (Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q2_K, Q3_K, Q4_K, Q5_K, Q6_K) +- **Mixed precision quantization** - quantize different layers with different quantization types (NEW!) + +## Basic Usage + +```bash +./quantize model-f32.bin model-quant.bin type +``` + +Where `type` is one of: q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k, q5_k, q6_k + +## Mixed Precision Quantization + +You can now specify different quantization types for different tensors using the `--tensor-type` option: + +```bash +./quantize [--tensor-type PATTERN=TYPE ...] model-f32.bin model-quant.bin default_type +``` + +### Examples + +**Quantize encoder with Q8_0 (higher quality) and decoder with Q4_0 (smaller size):** +```bash +./quantize \ + --tensor-type 'encoder\..*\.weight'=q8_0 \ + --tensor-type 'decoder\..*\.weight'=q4_0 \ + model-f32.bin model-mixed.bin q4_k +``` + +**Keep attention layers at higher precision:** +```bash +./quantize \ + --tensor-type '.*attn.*'=q8_0 \ + model-f32.bin model-mixed.bin q4_0 +``` + +For more detailed documentation and examples, see [README_MIXED_PRECISION.md](README_MIXED_PRECISION.md). diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 4dbae205551..a41896181d0 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -37,7 +37,11 @@ struct whisper_filters { }; // quantize a model -static bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { +static bool whisper_model_quantize( + const std::string & fname_inp, + const std::string & fname_out, + ggml_ftype ftype, + const std::vector & tensor_quant_specs = {}) { gpt_vocab vocab; printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); @@ -83,7 +87,12 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str finp.read((char *) &hparams.ftype, sizeof(hparams.ftype)); const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR; - const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype; + + // For mixed precision quantization, use F16 as the base ftype to ensure + // all tensor buffers are large enough to hold any quantization type + const bool use_mixed_precision = !tensor_quant_specs.empty(); + const int32_t ftype_for_allocation = use_mixed_precision ? GGML_FTYPE_MOSTLY_F16 : ftype; + const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype_for_allocation; fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); @@ -99,6 +108,9 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src); fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst); fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION); + if (use_mixed_precision) { + fprintf(stderr, "%s: using mixed precision quantization (ftype for allocation = F16)\n", __func__); + } fout.write((const char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); fout.write((const char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); @@ -165,7 +177,15 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str "decoder.positional_embedding", }; - if (!ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip)) { + // Use the extended quantization function if we have per-tensor specs + bool success; + if (!tensor_quant_specs.empty()) { + success = ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip, tensor_quant_specs); + } else { + success = ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip); + } + + if (!success) { fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str()); return false; } @@ -179,12 +199,67 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str int main(int argc, char ** argv) { ggml_backend_load_all(); - if (argc != 4) { - fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + if (argc < 4) { + fprintf(stderr, "usage: %s [--tensor-type PATTERN=TYPE ...] model-f32.bin model-quant.bin type\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, " --tensor-type PATTERN=TYPE : specify quantization type for tensors matching PATTERN\n"); + fprintf(stderr, " PATTERN is a regex pattern to match tensor names\n"); + fprintf(stderr, " TYPE is a quantization type (e.g., q4_0, q8_0, f16)\n"); + fprintf(stderr, " Example: --tensor-type 'encoder\\..*\\.weight'=q8_0 --tensor-type 'decoder\\..*\\.weight'=q4_0\n"); + fprintf(stderr, "\n"); ggml_print_ftypes(stderr); return 1; } + // Parse optional arguments + std::vector tensor_quant_specs; + int arg_idx = 1; + + while (arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0) { + if (strcmp(argv[arg_idx], "--tensor-type") == 0) { + if (arg_idx + 1 >= argc) { + fprintf(stderr, "error: --tensor-type requires an argument\n"); + return 1; + } + arg_idx++; + + // Parse PATTERN=TYPE + const char * spec_str = argv[arg_idx]; + const char * eq = strchr(spec_str, '='); + if (eq == nullptr) { + fprintf(stderr, "error: invalid --tensor-type format '%s', expected PATTERN=TYPE\n", spec_str); + return 1; + } + + std::string pattern(spec_str, eq - spec_str); + std::string type_str(eq + 1); + + ggml_type qtype = ggml_parse_qtype(type_str.c_str()); + if (qtype == GGML_TYPE_COUNT) { + fprintf(stderr, "error: unknown quantization type '%s'\n", type_str.c_str()); + return 1; + } + + tensor_quant_spec spec; + spec.pattern = pattern; + spec.quant_type = qtype; + tensor_quant_specs.push_back(spec); + + printf("Added tensor quantization spec: pattern='%s' type=%s\n", + pattern.c_str(), ggml_type_name(qtype)); + } else { + fprintf(stderr, "error: unknown option '%s'\n", argv[arg_idx]); + return 1; + } + arg_idx++; + } + + if (argc - arg_idx < 3) { + fprintf(stderr, "error: missing required arguments\n"); + fprintf(stderr, "usage: %s [--tensor-type PATTERN=TYPE ...] model-f32.bin model-quant.bin type\n", argv[0]); + return 1; + } + // needed to initialize f16 tables { struct ggml_init_params params = { 0, NULL, false }; @@ -192,10 +267,10 @@ int main(int argc, char ** argv) { ggml_free(ctx); } - const std::string fname_inp = argv[1]; - const std::string fname_out = argv[2]; + const std::string fname_inp = argv[arg_idx]; + const std::string fname_out = argv[arg_idx + 1]; - const ggml_ftype ftype = ggml_parse_ftype(argv[3]); + const ggml_ftype ftype = ggml_parse_ftype(argv[arg_idx + 2]); const int64_t t_main_start_us = ggml_time_us(); @@ -205,7 +280,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) { + if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype), tensor_quant_specs)) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } diff --git a/src/whisper.cpp b/src/whisper.cpp index 5b6e4b4be48..1132c8e6c2e 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -1911,28 +1911,61 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con return false; } - const size_t bpe = ggml_type_size(ggml_type(ttype)); - - if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { - WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); - return false; + // Calculate size based on file's tensor type + const size_t file_tensor_size = ggml_row_size(ggml_type(ttype), ne[0]) * (nelements / ne[0]); + const size_t expected_tensor_size = ggml_nbytes(tensor); + + // For mixed precision models, the tensor type in file may differ from the type + // the tensor was created with. We need to handle this carefully. + if (tensor->type != ggml_type(ttype)) { + // Mixed precision: tensor created with one type, file has another + // We need to update the tensor's type to match the file + WHISPER_LOG_DEBUG("%s: tensor '%s' type mismatch (expected %s, file has %s)\n", + __func__, name.data(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype))); + + // Check if the allocated buffer is large enough for the file's data + if (file_tensor_size > expected_tensor_size) { + WHISPER_LOG_ERROR("%s: tensor '%s' buffer too small: allocated %zu bytes for %s, but file needs %zu bytes for %s\n", + __func__, name.data(), expected_tensor_size, ggml_type_name(tensor->type), + file_tensor_size, ggml_type_name(ggml_type(ttype))); + return false; + } + + // Update tensor type to match the file + tensor->type = ggml_type(ttype); + + // Update tensor strides (nb) based on new type + tensor->nb[0] = ggml_type_size(tensor->type); + tensor->nb[1] = tensor->nb[0] * (tensor->ne[0] / ggml_blck_size(tensor->type)); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + tensor->nb[i] = tensor->nb[i-1] * tensor->ne[i-1]; + } + } else { + // Normal case: types match, verify size + if (file_tensor_size != expected_tensor_size) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), expected_tensor_size, file_tensor_size); + return false; + } } + + // Now read the data - use the file's size + const size_t bytes_to_read = file_tensor_size; if (ggml_backend_buffer_is_host(tensor->buffer)) { // for the CPU and Metal backend, we can read directly into the tensor - loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + loader->read(loader->context, tensor->data, bytes_to_read); BYTESWAP_TENSOR(tensor); } else { // read into a temporary buffer first, then copy to device memory - read_buf.resize(ggml_nbytes(tensor)); + read_buf.resize(bytes_to_read); loader->read(loader->context, read_buf.data(), read_buf.size()); - ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + ggml_backend_tensor_set(tensor, read_buf.data(), 0, bytes_to_read); } - total_size += ggml_nbytes(tensor); + total_size += bytes_to_read; model.n_loaded++; }