]> git.armaanb.net Git - phrases.git/blobdiff - phrases.py
make database-file location more flexible
[phrases.git] / phrases.py
index 63199d141d996cccc70d160ad3f71438f2909a32..43cc774e9266c5be8364f1977bddbede31a6aa77 100755 (executable)
@@ -6,6 +6,7 @@ import argparse
 import csv
 import random
 import sys
+import os.path
 
 def main(args=sys.argv[1:]):
     # Argument parsing
@@ -35,17 +36,26 @@ def main(args=sys.argv[1:]):
                         action='store_true',
                         help="print number of possibilities within constraints")
     parser.add_argument("-f", "--file",
-                        default="/usr/share/phrases/phrases.csv",
                         help="set the location of the phrase file")
     args = parser.parse_args()
 
     right_length = []
 
+    # find phrase file
+    if args.file:
+        phrase_file = args.file
+    if os.path.isfile("phrases.csv"):
+        phrase_file = "phrases.csv"
+    elif os.path.isfile("/usr/local/share/phrases/phrases.csv"):
+        phrase_file = "/usr/local/share/phrases/phrases.csv"
+    else:
+        sys.exit("cannot fine phrase database!")
+
     # convert csv file into list
-    with open(args.file) as f:
+    with open(phrase_file) as f:
         reader = csv.reader(f)
         next(reader, None) # skip header
-        all_lines = list(reader) 
+        all_lines = list(reader)
     f.close()
 
     # iterate through all the phrases
@@ -59,7 +69,7 @@ def main(args=sys.argv[1:]):
     try: # choose a random id from the shortlist
         chosen = int(right_length[random.randint(0, len(right_length) - 1)])
     except:
-        sys.exit("No phrase within the given parameters!")
+        sys.exit("no phrase within the given parameters!")
 
     # Output as specified in flags
     for row in all_lines:
@@ -70,6 +80,7 @@ def main(args=sys.argv[1:]):
                     or args.notes
                     or args.num):
                 print(row[1])
+                sys.exit(0)
             else:
                 if args.id:
                     print(row[1])
@@ -81,6 +92,7 @@ def main(args=sys.argv[1:]):
                     print(row[3])
                 if args.num:
                     print(len(right_length))
+                sys.exit(0)
 
 if __name__ == "__main__":
     main()