git_command.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. #
  2. # Copyright (C) 2008 The Android Open Source Project
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import print_function
  16. import fcntl
  17. import os
  18. import select
  19. import sys
  20. import subprocess
  21. import tempfile
  22. from signal import SIGTERM
  23. from error import GitError
  24. from trace import REPO_TRACE, IsTrace, Trace
  25. from wrapper import Wrapper
  26. GIT = 'git'
  27. MIN_GIT_VERSION = (1, 5, 4)
  28. GIT_DIR = 'GIT_DIR'
  29. LAST_GITDIR = None
  30. LAST_CWD = None
  31. _ssh_proxy_path = None
  32. _ssh_sock_path = None
  33. _ssh_clients = []
  34. def ssh_sock(create=True):
  35. global _ssh_sock_path
  36. if _ssh_sock_path is None:
  37. if not create:
  38. return None
  39. tmp_dir = '/tmp'
  40. if not os.path.exists(tmp_dir):
  41. tmp_dir = tempfile.gettempdir()
  42. _ssh_sock_path = os.path.join(
  43. tempfile.mkdtemp('', 'ssh-', tmp_dir),
  44. 'master-%r@%h:%p')
  45. return _ssh_sock_path
  46. def _ssh_proxy():
  47. global _ssh_proxy_path
  48. if _ssh_proxy_path is None:
  49. _ssh_proxy_path = os.path.join(
  50. os.path.dirname(__file__),
  51. 'git_ssh')
  52. return _ssh_proxy_path
  53. def _add_ssh_client(p):
  54. _ssh_clients.append(p)
  55. def _remove_ssh_client(p):
  56. try:
  57. _ssh_clients.remove(p)
  58. except ValueError:
  59. pass
  60. def terminate_ssh_clients():
  61. global _ssh_clients
  62. for p in _ssh_clients:
  63. try:
  64. os.kill(p.pid, SIGTERM)
  65. p.wait()
  66. except OSError:
  67. pass
  68. _ssh_clients = []
  69. _git_version = None
  70. class _sfd(object):
  71. """select file descriptor class"""
  72. def __init__(self, fd, dest, std_name):
  73. assert std_name in ('stdout', 'stderr')
  74. self.fd = fd
  75. self.dest = dest
  76. self.std_name = std_name
  77. def fileno(self):
  78. return self.fd.fileno()
  79. class _GitCall(object):
  80. def version(self):
  81. p = GitCommand(None, ['--version'], capture_stdout=True)
  82. if p.Wait() == 0:
  83. if hasattr(p.stdout, 'decode'):
  84. return p.stdout.decode('utf-8')
  85. else:
  86. return p.stdout
  87. return None
  88. def version_tuple(self):
  89. global _git_version
  90. if _git_version is None:
  91. ver_str = git.version()
  92. _git_version = Wrapper().ParseGitVersion(ver_str)
  93. if _git_version is None:
  94. print('fatal: "%s" unsupported' % ver_str, file=sys.stderr)
  95. sys.exit(1)
  96. return _git_version
  97. def __getattr__(self, name):
  98. name = name.replace('_','-')
  99. def fun(*cmdv):
  100. command = [name]
  101. command.extend(cmdv)
  102. return GitCommand(None, command).Wait() == 0
  103. return fun
  104. git = _GitCall()
  105. def git_require(min_version, fail=False):
  106. git_version = git.version_tuple()
  107. if min_version <= git_version:
  108. return True
  109. if fail:
  110. need = '.'.join(map(str, min_version))
  111. print('fatal: git %s or later required' % need, file=sys.stderr)
  112. sys.exit(1)
  113. return False
  114. def _setenv(env, name, value):
  115. env[name] = value.encode()
  116. class GitCommand(object):
  117. def __init__(self,
  118. project,
  119. cmdv,
  120. bare = False,
  121. provide_stdin = False,
  122. capture_stdout = False,
  123. capture_stderr = False,
  124. disable_editor = False,
  125. ssh_proxy = False,
  126. cwd = None,
  127. gitdir = None):
  128. env = os.environ.copy()
  129. for key in [REPO_TRACE,
  130. GIT_DIR,
  131. 'GIT_ALTERNATE_OBJECT_DIRECTORIES',
  132. 'GIT_OBJECT_DIRECTORY',
  133. 'GIT_WORK_TREE',
  134. 'GIT_GRAFT_FILE',
  135. 'GIT_INDEX_FILE']:
  136. if key in env:
  137. del env[key]
  138. # If we are not capturing std* then need to print it.
  139. self.tee = {'stdout': not capture_stdout, 'stderr': not capture_stderr}
  140. if disable_editor:
  141. _setenv(env, 'GIT_EDITOR', ':')
  142. if ssh_proxy:
  143. _setenv(env, 'REPO_SSH_SOCK', ssh_sock())
  144. _setenv(env, 'GIT_SSH', _ssh_proxy())
  145. if 'http_proxy' in env and 'darwin' == sys.platform:
  146. s = "'http.proxy=%s'" % (env['http_proxy'],)
  147. p = env.get('GIT_CONFIG_PARAMETERS')
  148. if p is not None:
  149. s = p + ' ' + s
  150. _setenv(env, 'GIT_CONFIG_PARAMETERS', s)
  151. if 'GIT_ALLOW_PROTOCOL' not in env:
  152. _setenv(env, 'GIT_ALLOW_PROTOCOL',
  153. 'file:git:http:https:ssh:persistent-http:persistent-https:sso:rpc')
  154. if project:
  155. if not cwd:
  156. cwd = project.worktree
  157. if not gitdir:
  158. gitdir = project.gitdir
  159. command = [GIT]
  160. if bare:
  161. if gitdir:
  162. _setenv(env, GIT_DIR, gitdir)
  163. cwd = None
  164. command.append(cmdv[0])
  165. # Need to use the --progress flag for fetch/clone so output will be
  166. # displayed as by default git only does progress output if stderr is a TTY.
  167. if sys.stderr.isatty() and cmdv[0] in ('fetch', 'clone'):
  168. if '--progress' not in cmdv and '--quiet' not in cmdv:
  169. command.append('--progress')
  170. command.extend(cmdv[1:])
  171. if provide_stdin:
  172. stdin = subprocess.PIPE
  173. else:
  174. stdin = None
  175. stdout = subprocess.PIPE
  176. stderr = subprocess.PIPE
  177. if IsTrace():
  178. global LAST_CWD
  179. global LAST_GITDIR
  180. dbg = ''
  181. if cwd and LAST_CWD != cwd:
  182. if LAST_GITDIR or LAST_CWD:
  183. dbg += '\n'
  184. dbg += ': cd %s\n' % cwd
  185. LAST_CWD = cwd
  186. if GIT_DIR in env and LAST_GITDIR != env[GIT_DIR]:
  187. if LAST_GITDIR or LAST_CWD:
  188. dbg += '\n'
  189. dbg += ': export GIT_DIR=%s\n' % env[GIT_DIR]
  190. LAST_GITDIR = env[GIT_DIR]
  191. dbg += ': '
  192. dbg += ' '.join(command)
  193. if stdin == subprocess.PIPE:
  194. dbg += ' 0<|'
  195. if stdout == subprocess.PIPE:
  196. dbg += ' 1>|'
  197. if stderr == subprocess.PIPE:
  198. dbg += ' 2>|'
  199. Trace('%s', dbg)
  200. try:
  201. p = subprocess.Popen(command,
  202. cwd = cwd,
  203. env = env,
  204. stdin = stdin,
  205. stdout = stdout,
  206. stderr = stderr)
  207. except Exception as e:
  208. raise GitError('%s: %s' % (command[1], e))
  209. if ssh_proxy:
  210. _add_ssh_client(p)
  211. self.process = p
  212. self.stdin = p.stdin
  213. def Wait(self):
  214. try:
  215. p = self.process
  216. rc = self._CaptureOutput()
  217. finally:
  218. _remove_ssh_client(p)
  219. return rc
  220. def _CaptureOutput(self):
  221. p = self.process
  222. s_in = [_sfd(p.stdout, sys.stdout, 'stdout'),
  223. _sfd(p.stderr, sys.stderr, 'stderr')]
  224. self.stdout = ''
  225. self.stderr = ''
  226. for s in s_in:
  227. flags = fcntl.fcntl(s.fd, fcntl.F_GETFL)
  228. fcntl.fcntl(s.fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
  229. while s_in:
  230. in_ready, _, _ = select.select(s_in, [], [])
  231. for s in in_ready:
  232. buf = s.fd.read(4096)
  233. if not buf:
  234. s_in.remove(s)
  235. continue
  236. if not hasattr(buf, 'encode'):
  237. buf = buf.decode()
  238. if s.std_name == 'stdout':
  239. self.stdout += buf
  240. else:
  241. self.stderr += buf
  242. if self.tee[s.std_name]:
  243. s.dest.write(buf)
  244. s.dest.flush()
  245. return p.wait()