
Commit

Add tensors that test QAT.
PiperOrigin-RevId: 648803095
lingvo-bot authored and copybara-github committed Jul 2, 2024
1 parent bc1cfd8 commit b162979
Showing 11 changed files with 72 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lingvo/core/BUILD
@@ -1475,6 +1475,7 @@ pytype_strict_test(
size = "medium",
srcs = ["py_utils_test.py"],
args = ["--noenable_eager_execution"],
data = ["//lingvo/core/testdata:quantization_test_data"],
deps = [
":py_utils_test_lib",
# Implicit freezegun dependency.
@@ -1487,6 +1488,7 @@ pytype_strict_test(
name = "py_utils_eager_test",
srcs = ["py_utils_test.py"],
args = ["--enable_eager_execution"],
data = ["//lingvo/core/testdata:quantization_test_data"],
main = "py_utils_test.py",
deps = [
":py_utils_test_lib",
63 changes: 63 additions & 0 deletions lingvo/core/py_utils_test.py
@@ -1352,6 +1352,69 @@ def testQAT(self, qat_output, expected):
)
self.assertAllClose(self.evaluate(x), expected)

@parameterized.named_parameters(
(
'4bit_weight_qat_output_false',
False,
'core/testdata/qat_test_4bit_weights.npy',
'core/testdata/qat_test_output_4bit_weight_qat_false.npy',
),
(
'4bit_weight_qat_output_true',
True,
'core/testdata/qat_test_4bit_weights.npy',
'core/testdata/qat_test_output_4bit_weight_qat_true.npy',
),
(
'8bit_weight_qat_output_false',
False,
'core/testdata/qat_test_8bit_weights.npy',
'core/testdata/qat_test_output_8bit_weight_qat_false.npy',
),
(
'8bit_weight_qat_output_true',
True,
'core/testdata/qat_test_8bit_weights.npy',
'core/testdata/qat_test_output_8bit_weight_qat_true.npy',
),
)
def testEinsumQuantization(self, qat_output, weights_path, expected):
# num_tasks=1, input_dim=2, output_dim=3
weights_path = test_helper.test_src_dir_path(weights_path)
weights = tf.convert_to_tensor(np.load(weights_path), tf.float32)
bias_path = test_helper.test_src_dir_path('core/testdata/qat_test_bias.npy')
bias = tf.convert_to_tensor(np.load(bias_path), tf.float32)
inputs_path = test_helper.test_src_dir_path(
'core/testdata/qat_test_inputs.npy'
)
inputs = tf.convert_to_tensor(np.load(inputs_path), tf.float32)
output_path = test_helper.test_src_dir_path(expected)
output = tf.convert_to_tensor(np.load(output_path), tf.float32)

quant_layer_p = layers.MultitaskProjectionEinsumLayer.Params()
quant_layer_p.name = 'testQAT'
quant_layer_p.input_dim = 256
quant_layer_p.output_dim = 126
quant_layer_p.num_tasks = 8

with self.session(use_gpu=False):
x = self.evaluate(
py_utils.MultiTaskProjection(
weights=weights,
biases=bias,
inputs=inputs,
tasks=1,
einsum_order='select_and_multiply',
quant_layer=layers.MultitaskProjectionEinsumLayer(quant_layer_p),
w_q_name='w',
w_q_domain='default',
qat_output=qat_output,
)
)
    # Different server CPUs produce slightly different results; a tolerance on
    # the order of 1e-3 is a safe margin since the outputs are on the order of
    # 1e+4.
self.assertAllClose(x, output, atol=2.5e-3)

def testShardedFilePatternToGlob(self):
file_pattern = '/some/path/to/file@8'
self.assertEqual('/some/path/to/file-?????-of-00008',
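
The new test loads all of its fixtures from .npy files that this commit checks into lingvo/core/testdata. The diff does not record how those arrays were generated; the following is a minimal NumPy sketch of how comparable fixtures could be produced, using the dimensions noted in the test comment (num_tasks=1, input_dim=2, output_dim=3). The batch size, random seed, distribution, and the [num_tasks, input_dim, output_dim] weight layout are assumptions for illustration, not values taken from the commit.

# Illustrative only: regenerate fixtures with the shapes noted in the test
# comment. The seed, batch size, and distribution are assumptions.
import numpy as np

rng = np.random.default_rng(0)
num_tasks, input_dim, output_dim, batch = 1, 2, 3, 4  # batch=4 is assumed

weights = rng.standard_normal((num_tasks, input_dim, output_dim)).astype(np.float32)
bias = rng.standard_normal((num_tasks, output_dim)).astype(np.float32)
inputs = rng.standard_normal((batch, input_dim)).astype(np.float32)

np.save('qat_test_4bit_weights.npy', weights)  # file names mirror the checked-in fixtures
np.save('qat_test_bias.npy', bias)
np.save('qat_test_inputs.npy', inputs)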
7 changes: 7 additions & 0 deletions lingvo/core/testdata/BUILD
@@ -15,3 +15,10 @@ filegroup(
"en-1k.spm.*",
]),
)

filegroup(
name = "quantization_test_data",
data = glob([
"qat_test_*",
]),
)
Binary file added lingvo/core/testdata/qat_test_4bit_weights.npy
Binary file added lingvo/core/testdata/qat_test_8bit_weights.npy
Binary file added lingvo/core/testdata/qat_test_bias.npy
Binary file added lingvo/core/testdata/qat_test_inputs.npy
4 additional binary testdata files added (names not shown).

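Among the testdata files added above are 4-bit and 8-bit weight fixtures plus expected outputs for each qat_output setting. As a rough, generic illustration of what n-bit fake-quantized weights mean in QAT, the sketch below applies symmetric per-tensor fake quantization to a float array; this is not Lingvo's quantization code, and the per-tensor scaling scheme is an assumption chosen for simplicity.

import numpy as np

def fake_quantize(w, num_bits):
  """Generic symmetric per-tensor fake quantization (illustration only)."""
  qmax = 2 ** (num_bits - 1) - 1               # 7 for 4-bit, 127 for 8-bit
  scale = np.max(np.abs(w)) / qmax             # assumed per-tensor scale
  q = np.clip(np.round(w / scale), -qmax, qmax)
  return (q * scale).astype(w.dtype)           # dequantized values, still float32

# Example: round-trip a random weight tensor through 4-bit fake quantization.
w = np.random.default_rng(0).standard_normal((1, 2, 3)).astype(np.float32)
w4 = fake_quantize(w, num_bits=4)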