preprocess_shakespeare.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. """Preprocesses the Shakespeare dataset for federated training.
  2. These codes are adopted from LEAF with some modifications.
  3. """
  4. import collections
  5. import json
  6. import os
  7. import re
# Fixed seed kept for reproducibility; not referenced in this module —
# NOTE(review): presumably consumed by downstream shuffling; confirm callers.
RANDOM_SEED = 1234
# Regular expression to capture an actors name, and line continuation.
# CHARACTER_RE matches a one-space-indented "NAME. spoken text" line and
# captures (name, text); CONT_RE matches any one-space-indented line and
# captures the remainder as a continuation of the previous speech.
CHARACTER_RE = re.compile(r'^ ([a-zA-Z][a-zA-Z ]*)\. (.*)')
CONT_RE = re.compile(r'^ (.*)')
# The Comedy of Errors has errors in its indentation so we need to use
# different regular expressions (same patterns without the leading space).
COE_CHARACTER_RE = re.compile(r'^([a-zA-Z][a-zA-Z ]*)\. (.*)')
COE_CONT_RE = re.compile(r'^(.*)')
  16. def _match_character_regex(line, comedy_of_errors=False):
  17. return (COE_CHARACTER_RE.match(line) if comedy_of_errors
  18. else CHARACTER_RE.match(line))
  19. def _match_continuation_regex(line, comedy_of_errors=False):
  20. return (
  21. COE_CONT_RE.match(line) if comedy_of_errors else CONT_RE.match(line))
  22. def _split_into_plays(shakespeare_full):
  23. """Splits the full data by play."""
  24. # List of tuples (play_name, dict from character to list of lines)
  25. plays = []
  26. discarded_lines = [] # Track discarded lines.
  27. slines = shakespeare_full.splitlines(True)[1:]
  28. # skip contents, the sonnets, and all's well that ends well
  29. author_count = 0
  30. start_i = 0
  31. for i, l in enumerate(slines):
  32. if 'by William Shakespeare' in l:
  33. author_count += 1
  34. if author_count == 2:
  35. start_i = i - 5
  36. break
  37. slines = slines[start_i:]
  38. current_character = None
  39. comedy_of_errors = False
  40. for i, line in enumerate(slines):
  41. # This marks the end of the plays in the file.
  42. if i > 124195 - start_i:
  43. break
  44. # This is a pretty good heuristic for detecting the start of a new play:
  45. if 'by William Shakespeare' in line:
  46. current_character = None
  47. characters = collections.defaultdict(list)
  48. # The title will be 2, 3, 4, 5, 6, or 7 lines above "by William Shakespeare".
  49. if slines[i - 2].strip():
  50. title = slines[i - 2]
  51. elif slines[i - 3].strip():
  52. title = slines[i - 3]
  53. elif slines[i - 4].strip():
  54. title = slines[i - 4]
  55. elif slines[i - 5].strip():
  56. title = slines[i - 5]
  57. elif slines[i - 6].strip():
  58. title = slines[i - 6]
  59. else:
  60. title = slines[i - 7]
  61. title = title.strip()
  62. assert title, ('Parsing error on line %d. Expecting title 2 or 3 lines above.' % i)
  63. comedy_of_errors = (title == 'THE COMEDY OF ERRORS')
  64. # Degenerate plays are removed at the end of the method.
  65. plays.append((title, characters))
  66. continue
  67. match = _match_character_regex(line, comedy_of_errors)
  68. if match:
  69. character, snippet = match.group(1), match.group(2)
  70. # Some character names are written with multiple casings, e.g., SIR_Toby
  71. # and SIR_TOBY. To normalize the character names, we uppercase each name.
  72. # Note that this was not done in the original preprocessing and is a
  73. # recent fix.
  74. character = character.upper()
  75. if not (comedy_of_errors and character.startswith('ACT ')):
  76. characters[character].append(snippet)
  77. current_character = character
  78. continue
  79. else:
  80. current_character = None
  81. continue
  82. elif current_character:
  83. match = _match_continuation_regex(line, comedy_of_errors)
  84. if match:
  85. if comedy_of_errors and match.group(1).startswith('<'):
  86. current_character = None
  87. continue
  88. else:
  89. characters[current_character].append(match.group(1))
  90. continue
  91. # Didn't consume the line.
  92. line = line.strip()
  93. if line and i > 2646:
  94. # Before 2646 are the sonnets, which we expect to discard.
  95. discarded_lines.append('%d:%s' % (i, line))
  96. # Remove degenerate "plays".
  97. return [play for play in plays if len(play[1]) > 1], discarded_lines
  98. def _remove_nonalphanumerics(filename):
  99. return re.sub('\\W+', '_', filename)
  100. def play_and_character(play, character):
  101. return _remove_nonalphanumerics((play + '_' + character).replace(' ', '_'))
  102. def _get_train_test_by_character(plays, test_fraction=0.2):
  103. """
  104. Splits character data into train and test sets.
  105. if test_fraction <= 0, returns {} for all_test_examples
  106. plays := list of (play, dict) tuples where play is a string and dict
  107. is a dictionary with character names as keys
  108. """
  109. skipped_characters = 0
  110. all_train_examples = collections.defaultdict(list)
  111. all_test_examples = collections.defaultdict(list)
  112. def add_examples(example_dict, example_tuple_list):
  113. for play, character, sound_bite in example_tuple_list:
  114. example_dict[play_and_character(
  115. play, character)].append(sound_bite)
  116. users_and_plays = {}
  117. for play, characters in plays:
  118. curr_characters = list(characters.keys())
  119. for c in curr_characters:
  120. users_and_plays[play_and_character(play, c)] = play
  121. for character, sound_bites in characters.items():
  122. examples = [(play, character, sound_bite)
  123. for sound_bite in sound_bites]
  124. if len(examples) <= 2:
  125. skipped_characters += 1
  126. # Skip characters with fewer than 2 lines since we need at least one
  127. # train and one test line.
  128. continue
  129. train_examples = examples
  130. if test_fraction > 0:
  131. num_test = max(int(len(examples) * test_fraction), 1)
  132. train_examples = examples[:-num_test]
  133. test_examples = examples[-num_test:]
  134. assert len(test_examples) == num_test
  135. assert len(train_examples) >= len(test_examples)
  136. add_examples(all_test_examples, test_examples)
  137. add_examples(all_train_examples, train_examples)
  138. return users_and_plays, all_train_examples, all_test_examples
  139. def _write_data_by_character(examples, output_directory):
  140. """Writes a collection of data files by play & character."""
  141. if not os.path.exists(output_directory):
  142. os.makedirs(output_directory)
  143. for character_name, sound_bites in examples.items():
  144. filename = os.path.join(output_directory, character_name + '.txt')
  145. with open(filename, 'w') as output:
  146. for sound_bite in sound_bites:
  147. output.write(sound_bite + '\n')
  148. def shakespeare_preprocess(input_filename, output_directory):
  149. print('Splitting .txt data between users')
  150. input_filename = input_filename
  151. with open(input_filename, 'r') as input_file:
  152. shakespeare_full = input_file.read()
  153. plays, discarded_lines = _split_into_plays(shakespeare_full)
  154. print('Discarded %d lines' % len(discarded_lines))
  155. users_and_plays, all_examples, _ = _get_train_test_by_character(plays, test_fraction=-1.0)
  156. with open(os.path.join(output_directory, 'users_and_plays.json'), 'w') as ouf:
  157. json.dump(users_and_plays, ouf)
  158. _write_data_by_character(all_examples,
  159. os.path.join(output_directory,
  160. 'by_play_and_character/'))