1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
import functools

from fate_arch.computing import is_table
from federatedml.util import LOGGER
def check_schema(input_schema, output_schema):
    """Validate that an output schema did not lose information the input had.

    Args:
        input_schema: Schema of the input data; must support ``.get``.
        output_schema: Schema of the produced data, or ``None``.

    Raises:
        EnvironmentError: If ``output_schema`` is ``None``, or if the input
            schema carries a ``"header"`` entry while the output schema's
            ``"header"`` is ``None`` or absent.
    """
    LOGGER.debug(f"input schema: {input_schema} -> output schema: {output_schema}")

    # The caller only invokes this check when the input had a schema, so a
    # missing output schema is always an error.
    if output_schema is None:
        raise EnvironmentError(
            "output_schema is None while input data has schema.")

    header_in = input_schema.get("header")
    header_out = output_schema.get("header")

    # Nothing to complain about unless a header existed and then vanished.
    if header_in is None or header_out is not None:
        return

    raise EnvironmentError(
        "output header is None while input data has header.")
def assert_schema_consistent(func):
    """Decorate ``func`` so that its table outputs are schema-checked.

    The schema of the first table found among the call's positional and then
    keyword arguments is used as the reference. After ``func`` returns, every
    non-empty table in the result (either the result itself, or the items of
    a returned list/tuple) is passed with that reference to ``check_schema``.
    If no argument is a table, or the reference schema is ``None``, the
    result is returned unchecked.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature and return value as ``func``.
        ``functools.wraps`` keeps the wrapped function's name, docstring and
        ``__wrapped__`` reference available for logging and introspection.

    Raises:
        EnvironmentError: Propagated from ``check_schema`` when an output
            table lost its schema or header.
    """

    @functools.wraps(func)
    def _func(*args, **kwargs):
        # Lazily pick the schema of the first table argument; ``.schema`` is
        # only read for that first match.
        input_schema = next(
            (arg.schema for arg in (*args, *kwargs.values()) if is_table(arg)),
            None,
        )

        result = func(*args, **kwargs)
        if input_schema is None:
            return result

        if is_table(result) and result.count() > 0:
            # Single data set.
            check_schema(input_schema, result.schema)
        elif isinstance(result, (list, tuple)):
            # Multiple data sets; isinstance also covers list/tuple
            # subclasses, which a type-name comparison would miss.
            for output_data in result:
                # Empty tables are skipped, as in the single-table case.
                if is_table(output_data) and output_data.count() > 0:
                    check_schema(input_schema, output_data.schema)
        return result

    return _func
|