shake_utils.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. """
  2. Helper functions for preprocessing shakespeare data.
  3. This code is adapted from LEAF with some modifications.
  4. """
  5. import json
  6. import os
  7. import re
  8. def __txt_to_data(txt_dir, seq_length=80):
  9. """Parses text file in given directory into data for next-character model.
  10. Args:
  11. txt_dir: path to text file
  12. seq_length: length of strings in X
  13. """
  14. raw_text = ""
  15. with open(txt_dir, 'r') as inf:
  16. raw_text = inf.read()
  17. raw_text = raw_text.replace('\n', ' ')
  18. raw_text = re.sub(r" *", r' ', raw_text)
  19. dataX = []
  20. dataY = []
  21. for i in range(0, len(raw_text) - seq_length, 1):
  22. seq_in = raw_text[i:i + seq_length]
  23. seq_out = raw_text[i + seq_length]
  24. dataX.append(seq_in)
  25. dataY.append(seq_out)
  26. return dataX, dataY
  27. def parse_data_in(data_dir, users_and_plays_path, raw=False):
  28. """
  29. returns dictionary with keys: users, num_samples, user_data
  30. raw := bool representing whether to include raw text in all_data
  31. if raw is True, then user_data key
  32. removes users with no data
  33. """
  34. with open(users_and_plays_path, 'r') as inf:
  35. users_and_plays = json.load(inf)
  36. files = os.listdir(data_dir)
  37. users = []
  38. hierarchies = []
  39. num_samples = []
  40. user_data = {}
  41. for f in files:
  42. user = f[:-4]
  43. passage = ''
  44. filename = os.path.join(data_dir, f)
  45. with open(filename, 'r') as inf:
  46. passage = inf.read()
  47. dataX, dataY = __txt_to_data(filename)
  48. if (len(dataX) > 0):
  49. users.append(user)
  50. if raw:
  51. user_data[user] = {'raw': passage}
  52. else:
  53. user_data[user] = {}
  54. user_data[user]['x'] = dataX
  55. user_data[user]['y'] = dataY
  56. hierarchies.append(users_and_plays[user])
  57. num_samples.append(len(dataY))
  58. all_data = {}
  59. all_data['users'] = users
  60. all_data['hierarchies'] = hierarchies
  61. all_data['num_samples'] = num_samples
  62. all_data['user_data'] = user_data
  63. return all_data