Jelajahi Sumber

Some fixes for supporting python3

* Fix imports.
* Use python3 syntax.
* Wrap map() calls with list().
* Use list() only wherever needed.
  (Thanks Conley!)
* Fix dictionary iteration methods
  (s/iteritems/items/).
* Make use of sorted() in appropriate places
* Use iterators directly in the loop.
* Don't use .keys() wherever it isn't needed.
* Use sys.maxsize instead of sys.maxint

TODO:
* Make repo work fully with python3. :)

Some of this was done by the '2to3' tool [1], by
applying the needed fixes in a way that doesn't
break compatibility with python2.

Links:
[1]: http://docs.python.org/2/library/2to3.html

Change-Id: Ibdf3bf9a530d716db905733cb9bfef83a48820f7
Signed-off-by: Chirayu Desai <cdesai@cyanogenmod.org>
Chirayu Desai 13 tahun lalu
induk
melakukan
217ea7d274
14 mengubah file dengan 130 tambahan dan 90 penghapusan
  1. 1 1
      command.py
  2. 9 8
      git_config.py
  3. 7 2
      git_refs.py
  4. 6 1
      main.py
  5. 32 32
      manifest_xml.py
  6. 28 21
      project.py
  7. 2 2
      subcmds/__init__.py
  8. 2 3
      subcmds/branches.py
  9. 4 6
      subcmds/help.py
  10. 1 1
      subcmds/info.py
  11. 1 1
      subcmds/overview.py
  12. 8 3
      subcmds/status.py
  13. 23 8
      subcmds/sync.py
  14. 6 1
      subcmds/upload.py

+ 1 - 1
command.py

@@ -140,7 +140,7 @@ class Command(object):
     groups = [x for x in re.split(r'[,\s]+', groups) if x]
 
     if not args:
-      all_projects_list = all_projects.values()
+      all_projects_list = list(all_projects.values())
       derived_projects = {}
       for project in all_projects_list:
         if submodules_ok or project.sync_s:

+ 9 - 8
git_config.py

@@ -14,8 +14,9 @@
 # limitations under the License.
 
 from __future__ import print_function
-import cPickle
+
 import os
+import pickle
 import re
 import subprocess
 import sys
@@ -262,7 +263,7 @@ class GitConfig(object):
       Trace(': unpickle %s', self.file)
       fd = open(self._pickle, 'rb')
       try:
-        return cPickle.load(fd)
+        return pickle.load(fd)
       finally:
         fd.close()
     except EOFError:
@@ -271,7 +272,7 @@ class GitConfig(object):
     except IOError:
       os.remove(self._pickle)
       return None
-    except cPickle.PickleError:
+    except pickle.PickleError:
       os.remove(self._pickle)
       return None
 
@@ -279,13 +280,13 @@ class GitConfig(object):
     try:
       fd = open(self._pickle, 'wb')
       try:
-        cPickle.dump(cache, fd, cPickle.HIGHEST_PROTOCOL)
+        pickle.dump(cache, fd, pickle.HIGHEST_PROTOCOL)
       finally:
         fd.close()
     except IOError:
       if os.path.exists(self._pickle):
         os.remove(self._pickle)
-    except cPickle.PickleError:
+    except pickle.PickleError:
       if os.path.exists(self._pickle):
         os.remove(self._pickle)
 
@@ -537,8 +538,8 @@ class Remote(object):
     self.url = self._Get('url')
     self.review = self._Get('review')
     self.projectname = self._Get('projectname')
-    self.fetch = map(RefSpec.FromString,
-                     self._Get('fetch', all_keys=True))
+    self.fetch = list(map(RefSpec.FromString,
+                      self._Get('fetch', all_keys=True)))
     self._review_url = None
 
   def _InsteadOf(self):
@@ -657,7 +658,7 @@ class Remote(object):
     self._Set('url', self.url)
     self._Set('review', self.review)
     self._Set('projectname', self.projectname)
-    self._Set('fetch', map(str, self.fetch))
+    self._Set('fetch', list(map(str, self.fetch)))
 
   def _Set(self, key, value):
     key = 'remote.%s.%s' % (self.name, key)

+ 7 - 2
git_refs.py

@@ -66,7 +66,7 @@ class GitRefs(object):
   def _NeedUpdate(self):
     Trace(': scan refs %s', self._gitdir)
 
-    for name, mtime in self._mtime.iteritems():
+    for name, mtime in self._mtime.items():
       try:
         if mtime != os.path.getmtime(os.path.join(self._gitdir, name)):
           return True
@@ -89,7 +89,7 @@ class GitRefs(object):
     attempts = 0
     while scan and attempts < 5:
       scan_next = {}
-      for name, dest in scan.iteritems():
+      for name, dest in scan.items():
         if dest in self._phyref:
           self._phyref[name] = self._phyref[dest]
         else:
@@ -108,6 +108,7 @@ class GitRefs(object):
       return
     try:
       for line in fd:
+        line = str(line)
         if line[0] == '#':
           continue
         if line[0] == '^':
@@ -150,6 +151,10 @@ class GitRefs(object):
     finally:
       fd.close()
 
+    try:
+      ref_id = ref_id.decode()
+    except AttributeError:
+      pass
     if not ref_id:
       return
     ref_id = ref_id[:-1]

+ 6 - 1
main.py

@@ -50,6 +50,11 @@ from pager import RunPager
 
 from subcmds import all_commands
 
+try:
+  input = raw_input
+except NameError:
+  pass
+
 global_options = optparse.OptionParser(
                  usage="repo [-p|--paginate|--no-pager] COMMAND [ARGS]"
                  )
@@ -286,7 +291,7 @@ def _AddPasswordFromUserInput(handler, msg, req):
   if user is None:
     print(msg)
     try:
-      user = raw_input('User: ')
+      user = input('User: ')
       password = getpass.getpass()
     except KeyboardInterrupt:
       return

+ 32 - 32
manifest_xml.py

@@ -18,7 +18,15 @@ import itertools
 import os
 import re
 import sys
-import urlparse
+try:
+  # For python3
+  import urllib.parse
+except ImportError:
+  # For python2
+  import imp
+  import urlparse
+  urllib = imp.new_module('urllib')
+  urllib.parse = urlparse
 import xml.dom.minidom
 
 from git_config import GitConfig
@@ -30,8 +38,8 @@ MANIFEST_FILE_NAME = 'manifest.xml'
 LOCAL_MANIFEST_NAME = 'local_manifest.xml'
 LOCAL_MANIFESTS_DIR_NAME = 'local_manifests'
 
-urlparse.uses_relative.extend(['ssh', 'git'])
-urlparse.uses_netloc.extend(['ssh', 'git'])
+urllib.parse.uses_relative.extend(['ssh', 'git'])
+urllib.parse.uses_netloc.extend(['ssh', 'git'])
 
 class _Default(object):
   """Project defaults within the manifest."""
@@ -73,7 +81,7 @@ class _XmlRemote(object):
     # ie, if manifestUrl is of the form <hostname:port>
     if manifestUrl.find(':') != manifestUrl.find('/') - 1:
       manifestUrl = 'gopher://' + manifestUrl
-    url = urlparse.urljoin(manifestUrl, url)
+    url = urllib.parse.urljoin(manifestUrl, url)
     url = re.sub(r'^gopher://', '', url)
     if p:
       url = 'persistent-' + url
@@ -162,10 +170,8 @@ class XmlManifest(object):
       notice_element.appendChild(doc.createTextNode(indented_notice))
 
     d = self.default
-    sort_remotes = list(self.remotes.keys())
-    sort_remotes.sort()
 
-    for r in sort_remotes:
+    for r in sorted(self.remotes):
       self._RemoteToXml(self.remotes[r], doc, root)
     if self.remotes:
       root.appendChild(doc.createTextNode(''))
@@ -257,12 +263,11 @@ class XmlManifest(object):
         e.setAttribute('sync-s', 'true')
 
       if p.subprojects:
-        sort_projects = [subp.name for subp in p.subprojects]
-        sort_projects.sort()
+        sort_projects = list(sorted([subp.name for subp in p.subprojects]))
         output_projects(p, e, sort_projects)
 
-    sort_projects = [key for key in self.projects.keys()
-                     if not self.projects[key].parent]
+    sort_projects = list(sorted([key for key, value in self.projects.items()
+                     if not value.parent]))
     sort_projects.sort()
     output_projects(None, root, sort_projects)
 
@@ -386,9 +391,8 @@ class XmlManifest(object):
         name = self._reqatt(node, 'name')
         fp = os.path.join(include_root, name)
         if not os.path.isfile(fp):
-          raise ManifestParseError, \
-              "include %s doesn't exist or isn't a file" % \
-              (name,)
+          raise ManifestParseError("include %s doesn't exist or isn't a file"
+              % (name,))
         try:
           nodes.extend(self._ParseManifestXml(fp, include_root))
         # should isolate this to the exact exception, but that's
@@ -494,7 +498,7 @@ class XmlManifest(object):
     name = None
     m_url = m.GetRemote(m.remote.name).url
     if m_url.endswith('/.git'):
-      raise ManifestParseError, 'refusing to mirror %s' % m_url
+      raise ManifestParseError('refusing to mirror %s' % m_url)
 
     if self._default and self._default.remote:
       url = self._default.remote.resolvedFetchUrl
@@ -588,7 +592,7 @@ class XmlManifest(object):
 
     # Figure out minimum indentation, skipping the first line (the same line
     # as the <notice> tag)...
-    minIndent = sys.maxint
+    minIndent = sys.maxsize
     lines = notice.splitlines()
     for line in lines[1:]:
       lstrippedLine = line.lstrip()
@@ -627,25 +631,22 @@ class XmlManifest(object):
     if remote is None:
       remote = self._default.remote
     if remote is None:
-      raise ManifestParseError, \
-            "no remote for project %s within %s" % \
-            (name, self.manifestFile)
+      raise ManifestParseError("no remote for project %s within %s" %
+            (name, self.manifestFile))
 
     revisionExpr = node.getAttribute('revision')
     if not revisionExpr:
       revisionExpr = self._default.revisionExpr
     if not revisionExpr:
-      raise ManifestParseError, \
-            "no revision for project %s within %s" % \
-            (name, self.manifestFile)
+      raise ManifestParseError("no revision for project %s within %s" %
+            (name, self.manifestFile))
 
     path = node.getAttribute('path')
     if not path:
       path = name
     if path.startswith('/'):
-      raise ManifestParseError, \
-            "project %s path cannot be absolute in %s" % \
-            (name, self.manifestFile)
+      raise ManifestParseError("project %s path cannot be absolute in %s" %
+            (name, self.manifestFile))
 
     rebase = node.getAttribute('rebase')
     if not rebase:
@@ -764,7 +765,8 @@ class XmlManifest(object):
     except ManifestParseError:
       keep = "true"
     if keep != "true" and keep != "false":
-      raise ManifestParseError, "optional \"keep\" attribute must be \"true\" or \"false\""
+      raise ManifestParseError('optional "keep" attribute must be '
+            '"true" or "false"')
     project.AddAnnotation(name, value, keep)
 
   def _get_remote(self, node):
@@ -774,9 +776,8 @@ class XmlManifest(object):
 
     v = self._remotes.get(name)
     if not v:
-      raise ManifestParseError, \
-            "remote %s not defined in %s" % \
-            (name, self.manifestFile)
+      raise ManifestParseError("remote %s not defined in %s" %
+            (name, self.manifestFile))
     return v
 
   def _reqatt(self, node, attname):
@@ -785,7 +786,6 @@ class XmlManifest(object):
     """
     v = node.getAttribute(attname)
     if not v:
-      raise ManifestParseError, \
-            "no %s in <%s> within %s" % \
-            (attname, node.nodeName, self.manifestFile)
+      raise ManifestParseError("no %s in <%s> within %s" %
+            (attname, node.nodeName, self.manifestFile))
     return v

+ 28 - 21
project.py

@@ -36,6 +36,11 @@ from trace import IsTrace, Trace
 
 from git_refs import GitRefs, HEAD, R_HEADS, R_TAGS, R_PUB, R_M
 
+try:
+  input = raw_input
+except NameError:
+  pass
+
 def _lwrite(path, content):
   lock = '%s.lock' % path
 
@@ -78,7 +83,7 @@ def _ProjectHooks():
   if _project_hook_list is None:
     d = os.path.abspath(os.path.dirname(__file__))
     d = os.path.join(d , 'hooks')
-    _project_hook_list = map(lambda x: os.path.join(d, x), os.listdir(d))
+    _project_hook_list = [os.path.join(d, x) for x in os.listdir(d)]
   return _project_hook_list
 
 
@@ -361,7 +366,7 @@ class RepoHook(object):
                  'Do you want to allow this script to run '
                  '(yes/yes-never-ask-again/NO)? ') % (
                  self._GetMustVerb(), self._script_fullpath)
-      response = raw_input(prompt).lower()
+      response = input(prompt).lower()
       print()
 
       # User is doing a one-time approval.
@@ -646,7 +651,7 @@ class Project(object):
     all_refs = self._allrefs
     heads = {}
 
-    for name, ref_id in all_refs.iteritems():
+    for name, ref_id in all_refs.items():
       if name.startswith(R_HEADS):
         name = name[len(R_HEADS):]
         b = self.GetBranch(name)
@@ -655,7 +660,7 @@ class Project(object):
         b.revision = ref_id
         heads[name] = b
 
-    for name, ref_id in all_refs.iteritems():
+    for name, ref_id in all_refs.items():
       if name.startswith(R_PUB):
         name = name[len(R_PUB):]
         b = heads.get(name)
@@ -761,10 +766,7 @@ class Project(object):
     paths.extend(df.keys())
     paths.extend(do)
 
-    paths = list(set(paths))
-    paths.sort()
-
-    for p in paths:
+    for p in sorted(set(paths)):
       try:
         i = di[p]
       except KeyError:
@@ -856,13 +858,13 @@ class Project(object):
       all_refs = self._allrefs
     heads = set()
     canrm = {}
-    for name, ref_id in all_refs.iteritems():
+    for name, ref_id in all_refs.items():
       if name.startswith(R_HEADS):
         heads.add(name)
       elif name.startswith(R_PUB):
         canrm[name] = ref_id
 
-    for name, ref_id in canrm.iteritems():
+    for name, ref_id in canrm.items():
       n = name[len(R_PUB):]
       if R_HEADS + n not in heads:
         self.bare_git.DeleteRef(name, ref_id)
@@ -873,14 +875,14 @@ class Project(object):
     heads = {}
     pubed = {}
 
-    for name, ref_id in self._allrefs.iteritems():
+    for name, ref_id in self._allrefs.items():
       if name.startswith(R_HEADS):
         heads[name[len(R_HEADS):]] = ref_id
       elif name.startswith(R_PUB):
         pubed[name[len(R_PUB):]] = ref_id
 
     ready = []
-    for branch, ref_id in heads.iteritems():
+    for branch, ref_id in heads.items():
       if branch in pubed and pubed[branch] == ref_id:
         continue
       if selected_branch and branch != selected_branch:
@@ -1223,7 +1225,7 @@ class Project(object):
     cmd = ['fetch', remote.name]
     cmd.append('refs/changes/%2.2d/%d/%d' \
                % (change_id % 100, change_id, patch_id))
-    cmd.extend(map(str, remote.fetch))
+    cmd.extend(list(map(str, remote.fetch)))
     if GitCommand(self, cmd, bare=True).Wait() != 0:
       return None
     return DownloadedChange(self,
@@ -1612,7 +1614,7 @@ class Project(object):
         ids = set(all_refs.values())
         tmp = set()
 
-        for r, ref_id in GitRefs(ref_dir).all.iteritems():
+        for r, ref_id in GitRefs(ref_dir).all.items():
           if r not in all_refs:
             if r.startswith(R_TAGS) or remote.WritesTo(r):
               all_refs[r] = ref_id
@@ -1627,13 +1629,10 @@ class Project(object):
           ids.add(ref_id)
           tmp.add(r)
 
-        ref_names = list(all_refs.keys())
-        ref_names.sort()
-
         tmp_packed = ''
         old_packed = ''
 
-        for r in ref_names:
+        for r in sorted(all_refs):
           line = '%s %s\n' % (all_refs[r], r)
           tmp_packed += line
           if r not in tmp:
@@ -1666,7 +1665,7 @@ class Project(object):
         cmd.append('--no-tags')
       else:
         cmd.append('--tags')
-      cmd.append((u'+refs/heads/*:') + remote.ToLocal('refs/heads/*'))
+      cmd.append(str((u'+refs/heads/*:') + remote.ToLocal('refs/heads/*')))
     elif tag_name is not None:
       cmd.append('tag')
       cmd.append(tag_name)
@@ -1676,7 +1675,7 @@ class Project(object):
         branch = self.upstream
       if branch.startswith(R_HEADS):
         branch = branch[len(R_HEADS):]
-      cmd.append((u'+refs/heads/%s:' % branch) + remote.ToLocal('refs/heads/%s' % branch))
+      cmd.append(str((u'+refs/heads/%s:' % branch) + remote.ToLocal('refs/heads/%s' % branch)))
 
     ok = False
     for _i in range(2):
@@ -2102,6 +2101,10 @@ class Project(object):
         line = fd.read()
       finally:
         fd.close()
+      try:
+        line = line.decode()
+      except AttributeError:
+        pass
       if line.startswith('ref: '):
         return line[5:-1]
       return line[:-1]
@@ -2195,7 +2198,7 @@ class Project(object):
           if not git_require((1, 7, 2)):
             raise ValueError('cannot set config on command line for %s()'
                              % name)
-          for k, v in config.iteritems():
+          for k, v in config.items():
             cmdv.append('-c')
             cmdv.append('%s=%s' % (k, v))
         cmdv.append(name)
@@ -2211,6 +2214,10 @@ class Project(object):
                          name,
                          p.stderr))
         r = p.stdout
+        try:
+          r = r.decode()
+        except AttributeError:
+          pass
         if r.endswith('\n') and r.index('\n') == len(r) - 1:
           return r[:-1]
         return r

+ 2 - 2
subcmds/__init__.py

@@ -38,8 +38,8 @@ for py in os.listdir(my_dir):
     try:
       cmd = getattr(mod, clsn)()
     except AttributeError:
-      raise SyntaxError, '%s/%s does not define class %s' % (
-                         __name__, py, clsn)
+      raise SyntaxError('%s/%s does not define class %s' % (
+                         __name__, py, clsn))
 
     name = name.replace('_', '-')
     cmd.NAME = name

+ 2 - 3
subcmds/branches.py

@@ -98,14 +98,13 @@ is shown, then the branch appears in all projects.
     project_cnt = len(projects)
 
     for project in projects:
-      for name, b in project.GetBranches().iteritems():
+      for name, b in project.GetBranches().items():
         b.project = project
         if name not in all_branches:
           all_branches[name] = BranchInfo(name)
         all_branches[name].add(b)
 
-    names = all_branches.keys()
-    names.sort()
+    names = list(sorted(all_branches))
 
     if not names:
       print('   (no branches)', file=sys.stderr)

+ 4 - 6
subcmds/help.py

@@ -34,8 +34,7 @@ Displays detailed usage information about a command.
   def _PrintAllCommands(self):
     print('usage: repo COMMAND [ARGS]')
     print('The complete list of recognized repo commands are:')
-    commandNames = self.commands.keys()
-    commandNames.sort()
+    commandNames = list(sorted(self.commands))
 
     maxlen = 0
     for name in commandNames:
@@ -55,10 +54,9 @@ Displays detailed usage information about a command.
   def _PrintCommonCommands(self):
     print('usage: repo COMMAND [ARGS]')
     print('The most commonly used repo commands are:')
-    commandNames = [name
-                    for name in self.commands.keys()
-                    if self.commands[name].common]
-    commandNames.sort()
+    commandNames = list(sorted([name
+                    for name, command in self.commands.items()
+                    if command.common]))
 
     maxlen = 0
     for name in commandNames:

+ 1 - 1
subcmds/info.py

@@ -163,7 +163,7 @@ class Info(PagedCommand):
     all_branches = []
     for project in self.GetProjects(args):
       br = [project.GetUploadableBranch(x)
-            for x in project.GetBranches().keys()]
+            for x in project.GetBranches()]
       br = [x for x in br if x]
       if self.opt.current_branch:
         br = [x for x in br if x.name == project.CurrentBranch]

+ 1 - 1
subcmds/overview.py

@@ -42,7 +42,7 @@ are displayed.
     all_branches = []
     for project in self.GetProjects(args):
       br = [project.GetUploadableBranch(x)
-            for x in project.GetBranches().keys()]
+            for x in project.GetBranches()]
       br = [x for x in br if x]
       if opt.current_branch:
         br = [x for x in br if x.name == project.CurrentBranch]

+ 8 - 3
subcmds/status.py

@@ -21,10 +21,15 @@ except ImportError:
   import dummy_threading as _threading
 
 import glob
+try:
+  # For python2
+  import StringIO as io
+except ImportError:
+  # For python3
+  import io
 import itertools
 import os
 import sys
-import StringIO
 
 from color import Coloring
 
@@ -142,7 +147,7 @@ the following meanings:
       for project in all_projects:
         sem.acquire()
 
-        class BufList(StringIO.StringIO):
+        class BufList(io.StringIO):
           def dump(self, ostream):
             for entry in self.buflist:
               ostream.write(entry)
@@ -182,7 +187,7 @@ the following meanings:
       try:
         os.chdir(self.manifest.topdir)
 
-        outstring = StringIO.StringIO()
+        outstring = io.StringIO()
         self._FindOrphans(glob.glob('.*') + \
             glob.glob('*'), \
             proj_dirs, proj_dirs_parents, outstring)

+ 23 - 8
subcmds/sync.py

@@ -24,8 +24,24 @@ import socket
 import subprocess
 import sys
 import time
-import urlparse
-import xmlrpclib
+try:
+  # For python3
+  import urllib.parse
+except ImportError:
+  # For python2
+  import imp
+  import urlparse
+  urllib = imp.new_module('urllib')
+  urllib.parse = urlparse
+try:
+  # For python3
+  import xmlrpc.client
+except ImportError:
+  # For python2
+  import imp
+  import xmlrpclib
+  xmlrpc = imp.new_module('xmlrpc')
+  xmlrpc.client = xmlrpclib
 
 try:
   import threading as _threading
@@ -498,7 +514,7 @@ later is required to fix a server side protocol bug.
                   file=sys.stderr)
           else:
             try:
-              parse_result = urlparse.urlparse(manifest_server)
+              parse_result = urllib.parse(manifest_server)
               if parse_result.hostname:
                 username, _account, password = \
                   info.authenticators(parse_result.hostname)
@@ -516,7 +532,7 @@ later is required to fix a server side protocol bug.
                                                     1)
 
       try:
-        server = xmlrpclib.Server(manifest_server)
+        server = xmlrpc.client.Server(manifest_server)
         if opt.smart_sync:
           p = self.manifest.manifestProject
           b = p.GetBranch(p.CurrentBranch)
@@ -525,8 +541,7 @@ later is required to fix a server side protocol bug.
             branch = branch[len(R_HEADS):]
 
           env = os.environ.copy()
-          if (env.has_key('TARGET_PRODUCT') and
-              env.has_key('TARGET_BUILD_VARIANT')):
+          if 'TARGET_PRODUCT' in env and 'TARGET_BUILD_VARIANT' in env:
             target = '%s-%s' % (env['TARGET_PRODUCT'],
                                 env['TARGET_BUILD_VARIANT'])
             [success, manifest_str] = server.GetApprovedManifest(branch, target)
@@ -554,11 +569,11 @@ later is required to fix a server side protocol bug.
         else:
           print('error: %s' % manifest_str, file=sys.stderr)
           sys.exit(1)
-      except (socket.error, IOError, xmlrpclib.Fault) as e:
+      except (socket.error, IOError, xmlrpc.client.Fault) as e:
         print('error: cannot connect to manifest server %s:\n%s'
               % (self.manifest.manifest_server, e), file=sys.stderr)
         sys.exit(1)
-      except xmlrpclib.ProtocolError as e:
+      except xmlrpc.client.ProtocolError as e:
         print('error: cannot connect to manifest server %s:\n%d %s'
               % (self.manifest.manifest_server, e.errcode, e.errmsg),
               file=sys.stderr)

+ 6 - 1
subcmds/upload.py

@@ -23,6 +23,11 @@ from editor import Editor
 from error import HookError, UploadError
 from project import RepoHook
 
+try:
+  input = raw_input
+except NameError:
+  pass
+
 UNUSUAL_COMMIT_THRESHOLD = 5
 
 def _ConfirmManyUploads(multiple_branches=False):
@@ -33,7 +38,7 @@ def _ConfirmManyUploads(multiple_branches=False):
     print('ATTENTION: You are uploading an unusually high number of commits.')
   print('YOU PROBABLY DO NOT MEAN TO DO THIS. (Did you rebase across '
         'branches?)')
-  answer = raw_input("If you are sure you intend to do this, type 'yes': ").strip()
+  answer = input("If you are sure you intend to do this, type 'yes': ").strip()
   return answer == "yes"
 
 def _die(fmt, *args):