123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import numpy as np
- import unittest
- from federatedml.secureprotol.fixedpoint import FixedPointNumber
class TestFixedPointNumber(unittest.TestCase):
    """Tests for ``FixedPointNumber``.

    Covers encode/decode round trips, arithmetic compared against plain
    Python/numpy arithmetic, and the six comparison operators compared
    against the same operators on the raw values.

    Random inputs come from a ``numpy.random.RandomState`` seeded in
    ``setUp``, so every run draws the same values and a failure can be
    reproduced exactly.
    """

    # Seed for the per-test random generator; fixed for reproducibility.
    SEED = 2019
    # Number of samples each randomized check draws.
    N_SAMPLES = 100

    def setUp(self):
        super().setUp()
        # Private, seeded generator: tests neither depend on nor disturb the
        # global numpy random state, and failures are reproducible.
        self.rng = np.random.RandomState(self.SEED)

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #

    def _random_int_floats(self, high):
        """Return ``N_SAMPLES`` independent integer-valued float64 draws from [0, high).

        float64 (not int64) keeps the same ``encode`` input type the original
        ``np.ones(...) * randint(...)`` arrays produced.
        """
        return self.rng.randint(high, size=self.N_SAMPLES).astype(np.float64)

    def _draw_float(self):
        """Draw one float from a normal distribution with standard deviation 100."""
        return self.rng.randn() * 100

    def _draw_small_int(self):
        """Draw one int from [0, 10); the small range makes equal pairs frequent."""
        return self.rng.randint(10)

    def _assert_comparison_matches(self, compare, draw):
        """Assert ``compare`` agrees on encoded and raw operands.

        :param compare: two-argument comparison, e.g. ``lambda a, b: a < b``.
        :param draw: zero-argument callable returning one random operand.
        """
        for _ in range(self.N_SAMPLES):
            x = draw()
            y = draw()
            with self.subTest(x=x, y=y):
                en_z = compare(FixedPointNumber.encode(x), FixedPointNumber.encode(y))
                self.assertEqual(en_z, compare(x, y))

    # ------------------------------------------------------------------ #
    # encode / decode
    # ------------------------------------------------------------------ #

    def test_encode_decode(self):
        # Integers and their negations must round-trip exactly.
        for i in range(self.N_SAMPLES):
            for value in (i, -i):
                with self.subTest(value=value):
                    self.assertEqual(FixedPointNumber.encode(value).decode(), value)

        # Floats: a deterministic grid, uniform [0, 1) draws, and
        # integer-valued floats must round-trip to 7 decimal places.
        float_groups = (
            [i * 0.6 for i in range(self.N_SAMPLES)],
            self.rng.rand(self.N_SAMPLES),
            self._random_int_floats(100),
        )
        for group in float_groups:
            for x in group:
                with self.subTest(x=x):
                    self.assertAlmostEqual(FixedPointNumber.encode(x).decode(), x)

    # ------------------------------------------------------------------ #
    # arithmetic
    # ------------------------------------------------------------------ #

    def test_add(self):
        x_li = self._random_int_floats(100)
        y_li = self._random_int_floats(1000)
        z_li = self.rng.rand(self.N_SAMPLES)
        t_li = range(self.N_SAMPLES)
        for x, y, z, t in zip(x_li, y_li, z_li, t_li):
            with self.subTest(x=x, y=y, z=z, t=t):
                en_res = (FixedPointNumber.encode(x)
                          + FixedPointNumber.encode(y)
                          + FixedPointNumber.encode(-z)
                          + FixedPointNumber.encode(-t))
                res = x + y + (-z) + (-t)
                self.assertAlmostEqual(en_res.decode(), res)

    def test_sub(self):
        x_li = self._random_int_floats(100)
        y_li = self._random_int_floats(1000)
        z_li = self.rng.rand(self.N_SAMPLES)
        t_li = range(self.N_SAMPLES)
        for x, y, z, t in zip(x_li, y_li, z_li, t_li):
            with self.subTest(x=x, y=y, z=z, t=t):
                en_res = (FixedPointNumber.encode(x)
                          - FixedPointNumber.encode(y)
                          - FixedPointNumber.encode(z)
                          - FixedPointNumber.encode(t))
                res = x - y - z - t
                self.assertAlmostEqual(en_res.decode(), res)

    def test_mul(self):
        x_li = self._random_int_floats(100)
        y_li = -self._random_int_floats(1000)
        z_li = self.rng.rand(self.N_SAMPLES)
        t_li = range(self.N_SAMPLES)
        # Mixed FixedPointNumber (*, +) plain-scalar expressions.
        for x, y, z, t in zip(x_li, y_li, z_li, t_li):
            with self.subTest(x=x, y=y, z=z, t=t):
                en_res = (FixedPointNumber.encode(x) * y + z) * t
                res = (x * y + z) * t
                self.assertAlmostEqual(en_res.decode(), res)

        # Repeated int/float addition on one encoded value must track the
        # same sequence of plain float operations at every step.
        x = 9
        en_x = FixedPointNumber.encode(x)
        for step in range(100):
            en_x = en_x + 5000 - 0.2
            x = x + 5000 - 0.2
            with self.subTest(step=step):
                self.assertAlmostEqual(en_x.decode(), x)

    def test_div(self):
        for _ in range(self.N_SAMPLES):
            x = self._draw_float()
            y = self._draw_float()
            z = x / y
            de_en_z = (FixedPointNumber.encode(x) / FixedPointNumber.encode(y)).decode()
            # A divisor near zero makes |z| large; a fixed 7-decimal tolerance
            # would then fail on ordinary float rounding. Scale the tolerance
            # with |z| (it equals the original ~5e-8 bound when |z| <= 1).
            with self.subTest(x=x, y=y):
                self.assertAlmostEqual(de_en_z, z, delta=5e-8 * max(1.0, abs(z)))

    # ------------------------------------------------------------------ #
    # comparisons
    # ------------------------------------------------------------------ #

    def test_lt(self):
        self._assert_comparison_matches(lambda a, b: a < b, self._draw_float)

    def test_gt(self):
        self._assert_comparison_matches(lambda a, b: a > b, self._draw_float)

    def test_le(self):
        self._assert_comparison_matches(lambda a, b: a <= b, self._draw_small_int)

    def test_ge(self):
        self._assert_comparison_matches(lambda a, b: a >= b, self._draw_small_int)

    def test_eq(self):
        self._assert_comparison_matches(lambda a, b: a == b, self._draw_small_int)

    def test_ne(self):
        self._assert_comparison_matches(lambda a, b: a != b, self._draw_small_int)
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it;
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")
|