123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- """Preprocesses the Shakespeare dataset for federated training.
- These codes are adopted from LEAF with some modifications.
- """
- import collections
- import json
- import os
- import re
# Seed for downstream shuffling/sampling; unused in the code visible here —
# presumably consumed by callers. TODO(review): confirm against the rest of
# the pipeline.
RANDOM_SEED = 1234

# Regular expression to capture an actor's name ("NAME. first spoken line")
# and, separately, a continuation line of the current speech.
# NOTE(review): leading whitespace inside these patterns may have been
# collapsed by reformatting — upstream LEAF indents the non-COE patterns with
# two/four leading spaces; confirm against the original data format.
CHARACTER_RE = re.compile(r'^ ([a-zA-Z][a-zA-Z ]*)\. (.*)')
CONT_RE = re.compile(r'^ (.*)')

# The Comedy of Errors has errors in its indentation so we need to use
# different regular expressions (no leading indent before the name).
COE_CHARACTER_RE = re.compile(r'^([a-zA-Z][a-zA-Z ]*)\. (.*)')
COE_CONT_RE = re.compile(r'^(.*)')
def _match_character_regex(line, comedy_of_errors=False):
    """Matches a character-name line, using the COE pattern when required.

    Returns the `re.Match` object, or None when the line does not start a
    new character's speech.
    """
    pattern = COE_CHARACTER_RE if comedy_of_errors else CHARACTER_RE
    return pattern.match(line)
def _match_continuation_regex(line, comedy_of_errors=False):
    """Matches a continuation line of a speech, using the COE pattern when
    required.

    Returns the `re.Match` object, or None when the line is not a
    continuation.
    """
    pattern = COE_CONT_RE if comedy_of_errors else CONT_RE
    return pattern.match(line)
def _split_into_plays(shakespeare_full):
    """Splits the full data by play.

    Args:
      shakespeare_full: the complete-works text as one string (expected to be
        the standard Project Gutenberg file — the hard-coded line offsets
        below are calibrated to it; TODO(review): confirm which edition).

    Returns:
      A tuple (plays, discarded_lines) where plays is a list of
      (play_title, dict from character name to list of spoken lines), with
      degenerate plays (fewer than two characters) removed, and
      discarded_lines is the list of lines that matched no rule, formatted
      as 'lineno:text' for debugging.
    """
    # List of tuples (play_name, dict from character to list of lines)
    plays = []
    discarded_lines = []  # Track discarded lines.
    # splitlines(True) keeps line endings; drop the very first line.
    slines = shakespeare_full.splitlines(True)[1:]

    # Skip the table of contents, the sonnets, and "All's Well That Ends
    # Well": scan for the second occurrence of the byline and start 5 lines
    # above it (where that play's title sits).
    author_count = 0
    start_i = 0
    for i, l in enumerate(slines):
        if 'by William Shakespeare' in l:
            author_count += 1
        if author_count == 2:
            start_i = i - 5
            break
    slines = slines[start_i:]

    current_character = None
    comedy_of_errors = False
    for i, line in enumerate(slines):
        # This marks the end of the plays in the file: 124195 is an absolute
        # line offset in the original file, rebased by start_i.
        if i > 124195 - start_i:
            break
        # This is a pretty good heuristic for detecting the start of a new
        # play:
        if 'by William Shakespeare' in line:
            current_character = None
            characters = collections.defaultdict(list)
            # The title will be 2, 3, 4, 5, 6, or 7 lines above
            # "by William Shakespeare"; take the first non-blank candidate.
            if slines[i - 2].strip():
                title = slines[i - 2]
            elif slines[i - 3].strip():
                title = slines[i - 3]
            elif slines[i - 4].strip():
                title = slines[i - 4]
            elif slines[i - 5].strip():
                title = slines[i - 5]
            elif slines[i - 6].strip():
                title = slines[i - 6]
            else:
                title = slines[i - 7]
            title = title.strip()
            # NOTE(review): the message understates the search range (2-7
            # lines are actually checked).
            assert title, ('Parsing error on line %d. Expecting title 2 or 3 lines above.' % i)
            # The Comedy of Errors needs the alternate regexes (see module
            # constants).
            comedy_of_errors = (title == 'THE COMEDY OF ERRORS')
            # Degenerate plays are removed at the end of the method.
            plays.append((title, characters))
            continue
        match = _match_character_regex(line, comedy_of_errors)
        if match:
            character, snippet = match.group(1), match.group(2)
            # Some character names are written with multiple casings, e.g.,
            # SIR_Toby and SIR_TOBY. To normalize the character names, we
            # uppercase each name. Note that this was not done in the
            # original preprocessing and is a recent fix.
            character = character.upper()
            # In COE, act headings ("ACT ...") match the character regex;
            # skip them rather than treating them as a speaker.
            if not (comedy_of_errors and character.startswith('ACT ')):
                characters[character].append(snippet)
                current_character = character
                continue
            else:
                current_character = None
                continue
        elif current_character:
            match = _match_continuation_regex(line, comedy_of_errors)
            if match:
                # COE stage markup starts with '<'; it ends the speech.
                if comedy_of_errors and match.group(1).startswith('<'):
                    current_character = None
                    continue
                else:
                    characters[current_character].append(match.group(1))
                    continue
        # Didn't consume the line.
        line = line.strip()
        if line and i > 2646:
            # Before 2646 are the sonnets, which we expect to discard.
            discarded_lines.append('%d:%s' % (i, line))
    # Remove degenerate "plays" (fewer than two characters parsed).
    return [play for play in plays if len(play[1]) > 1], discarded_lines
- def _remove_nonalphanumerics(filename):
- return re.sub('\\W+', '_', filename)
def play_and_character(play, character):
    """Builds a filesystem-safe user id from a play title and character name."""
    combined = (play + '_' + character).replace(' ', '_')
    return _remove_nonalphanumerics(combined)
def _get_train_test_by_character(plays, test_fraction=0.2):
    """Splits each character's lines into train and test sets.

    Args:
      plays: list of (play_title, characters) tuples, where characters maps
        a character name to that character's list of lines.
      test_fraction: fraction of each character's lines held out for test
        (at least one line when positive). When <= 0, every example goes to
        train and the test dict stays empty.

    Returns:
      A tuple (users_and_plays, all_train_examples, all_test_examples):
      a dict mapping each user id to its play, and two defaultdicts mapping
      user id to lists of lines.
    """
    skipped_characters = 0
    all_train_examples = collections.defaultdict(list)
    all_test_examples = collections.defaultdict(list)

    def add_examples(example_dict, example_tuple_list):
        # Group each snippet under its play-and-character user id.
        for play, character, sound_bite in example_tuple_list:
            example_dict[play_and_character(play, character)].append(sound_bite)

    users_and_plays = {}
    for play, characters in plays:
        for name in list(characters.keys()):
            users_and_plays[play_and_character(play, name)] = play
        for character, sound_bites in characters.items():
            examples = [(play, character, bite) for bite in sound_bites]
            if len(examples) <= 2:
                # Too few lines to guarantee at least one train and one
                # test example each (characters with <= 2 lines are dropped).
                skipped_characters += 1
                continue
            train_examples = examples
            if test_fraction > 0:
                num_test = max(int(len(examples) * test_fraction), 1)
                train_examples = examples[:-num_test]
                test_examples = examples[-num_test:]
                assert len(test_examples) == num_test
                assert len(train_examples) >= len(test_examples)
                add_examples(all_test_examples, test_examples)
            add_examples(all_train_examples, train_examples)
    return users_and_plays, all_train_examples, all_test_examples
- def _write_data_by_character(examples, output_directory):
- """Writes a collection of data files by play & character."""
- if not os.path.exists(output_directory):
- os.makedirs(output_directory)
- for character_name, sound_bites in examples.items():
- filename = os.path.join(output_directory, character_name + '.txt')
- with open(filename, 'w') as output:
- for sound_bite in sound_bites:
- output.write(sound_bite + '\n')
def shakespeare_preprocess(input_filename, output_directory):
    """Splits the raw Shakespeare text into per-character data files.

    Writes 'users_and_plays.json' (user id -> play title) and a
    'by_play_and_character/' directory with one text file per user into
    output_directory.

    Args:
      input_filename: path to the complete-works Shakespeare text file.
      output_directory: directory that receives the outputs (created if
        missing).
    """
    print('Splitting .txt data between users')
    # (Removed a dead `input_filename = input_filename` self-assignment.)
    with open(input_filename, 'r') as input_file:
        shakespeare_full = input_file.read()
    plays, discarded_lines = _split_into_plays(shakespeare_full)
    print('Discarded %d lines' % len(discarded_lines))
    # A non-positive test_fraction keeps every example in the train split.
    users_and_plays, all_examples, _ = _get_train_test_by_character(
        plays, test_fraction=-1.0)
    # Previously only the by_play_and_character/ subdirectory was created
    # (inside _write_data_by_character); creating output_directory here
    # prevents a FileNotFoundError when writing the JSON mapping first.
    os.makedirs(output_directory, exist_ok=True)
    with open(os.path.join(output_directory, 'users_and_plays.json'), 'w') as ouf:
        json.dump(users_and_plays, ouf)
    _write_data_by_character(
        all_examples,
        os.path.join(output_directory, 'by_play_and_character/'))
|