migration.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838
  1. import logging
  2. import sys
  3. from contextlib import contextmanager
  4. from sqlalchemy import MetaData, Table, Column, String, literal_column,\
  5. PrimaryKeyConstraint
  6. from sqlalchemy.engine.strategies import MockEngineStrategy
  7. from sqlalchemy.engine import url as sqla_url
  8. from sqlalchemy.engine import Connection
  9. from ..util.compat import callable, EncodedIO
  10. from .. import ddl, util
# Module-level logger: MigrationContext reports the selected impl and
# transactional mode, run_migrations logs each step it runs, and
# HeadMaintainer.update_to_step logs version-table changes at DEBUG level.
log = logging.getLogger(__name__)
  12. class MigrationContext(object):
  13. """Represent the database state made available to a migration
  14. script.
  15. :class:`.MigrationContext` is the front end to an actual
  16. database connection, or alternatively a string output
  17. stream given a particular database dialect,
  18. from an Alembic perspective.
  19. When inside the ``env.py`` script, the :class:`.MigrationContext`
  20. is available via the
  21. :meth:`.EnvironmentContext.get_context` method,
  22. which is available at ``alembic.context``::
  23. # from within env.py script
  24. from alembic import context
  25. migration_context = context.get_context()
  26. For usage outside of an ``env.py`` script, such as for
  27. utility routines that want to check the current version
  28. in the database, the :meth:`.MigrationContext.configure`
  29. method to create new :class:`.MigrationContext` objects.
  30. For example, to get at the current revision in the
  31. database using :meth:`.MigrationContext.get_current_revision`::
  32. # in any application, outside of an env.py script
  33. from alembic.migration import MigrationContext
  34. from sqlalchemy import create_engine
  35. engine = create_engine("postgresql://mydatabase")
  36. conn = engine.connect()
  37. context = MigrationContext.configure(conn)
  38. current_rev = context.get_current_revision()
  39. The above context can also be used to produce
  40. Alembic migration operations with an :class:`.Operations`
  41. instance::
  42. # in any application, outside of the normal Alembic environment
  43. from alembic.operations import Operations
  44. op = Operations(context)
  45. op.alter_column("mytable", "somecolumn", nullable=True)
  46. """
  47. def __init__(self, dialect, connection, opts, environment_context=None):
  48. self.environment_context = environment_context
  49. self.opts = opts
  50. self.dialect = dialect
  51. self.script = opts.get('script')
  52. as_sql = opts.get('as_sql', False)
  53. transactional_ddl = opts.get("transactional_ddl")
  54. self._transaction_per_migration = opts.get(
  55. "transaction_per_migration", False)
  56. if as_sql:
  57. self.connection = self._stdout_connection(connection)
  58. assert self.connection is not None
  59. else:
  60. self.connection = connection
  61. self._migrations_fn = opts.get('fn')
  62. self.as_sql = as_sql
  63. if "output_encoding" in opts:
  64. self.output_buffer = EncodedIO(
  65. opts.get("output_buffer") or sys.stdout,
  66. opts['output_encoding']
  67. )
  68. else:
  69. self.output_buffer = opts.get("output_buffer", sys.stdout)
  70. self._user_compare_type = opts.get('compare_type', False)
  71. self._user_compare_server_default = opts.get(
  72. 'compare_server_default',
  73. False)
  74. self.version_table = version_table = opts.get(
  75. 'version_table', 'alembic_version')
  76. self.version_table_schema = version_table_schema = \
  77. opts.get('version_table_schema', None)
  78. self._version = Table(
  79. version_table, MetaData(),
  80. Column('version_num', String(32), nullable=False),
  81. schema=version_table_schema)
  82. if opts.get("version_table_pk", True):
  83. self._version.append_constraint(
  84. PrimaryKeyConstraint(
  85. 'version_num', name="%s_pkc" % version_table
  86. )
  87. )
  88. self._start_from_rev = opts.get("starting_rev")
  89. self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
  90. dialect, self.connection, self.as_sql,
  91. transactional_ddl,
  92. self.output_buffer,
  93. opts
  94. )
  95. log.info("Context impl %s.", self.impl.__class__.__name__)
  96. if self.as_sql:
  97. log.info("Generating static SQL")
  98. log.info("Will assume %s DDL.",
  99. "transactional" if self.impl.transactional_ddl
  100. else "non-transactional")
  101. @classmethod
  102. def configure(cls,
  103. connection=None,
  104. url=None,
  105. dialect_name=None,
  106. dialect=None,
  107. environment_context=None,
  108. opts=None,
  109. ):
  110. """Create a new :class:`.MigrationContext`.
  111. This is a factory method usually called
  112. by :meth:`.EnvironmentContext.configure`.
  113. :param connection: a :class:`~sqlalchemy.engine.Connection`
  114. to use for SQL execution in "online" mode. When present,
  115. is also used to determine the type of dialect in use.
  116. :param url: a string database url, or a
  117. :class:`sqlalchemy.engine.url.URL` object.
  118. The type of dialect to be used will be derived from this if
  119. ``connection`` is not passed.
  120. :param dialect_name: string name of a dialect, such as
  121. "postgresql", "mssql", etc. The type of dialect to be used will be
  122. derived from this if ``connection`` and ``url`` are not passed.
  123. :param opts: dictionary of options. Most other options
  124. accepted by :meth:`.EnvironmentContext.configure` are passed via
  125. this dictionary.
  126. """
  127. if opts is None:
  128. opts = {}
  129. if connection:
  130. if not isinstance(connection, Connection):
  131. util.warn(
  132. "'connection' argument to configure() is expected "
  133. "to be a sqlalchemy.engine.Connection instance, "
  134. "got %r" % connection)
  135. dialect = connection.dialect
  136. elif url:
  137. url = sqla_url.make_url(url)
  138. dialect = url.get_dialect()()
  139. elif dialect_name:
  140. url = sqla_url.make_url("%s://" % dialect_name)
  141. dialect = url.get_dialect()()
  142. elif not dialect:
  143. raise Exception("Connection, url, or dialect_name is required.")
  144. return MigrationContext(dialect, connection, opts, environment_context)
  145. def begin_transaction(self, _per_migration=False):
  146. transaction_now = _per_migration == self._transaction_per_migration
  147. if not transaction_now:
  148. @contextmanager
  149. def do_nothing():
  150. yield
  151. return do_nothing()
  152. elif not self.impl.transactional_ddl:
  153. @contextmanager
  154. def do_nothing():
  155. yield
  156. return do_nothing()
  157. elif self.as_sql:
  158. @contextmanager
  159. def begin_commit():
  160. self.impl.emit_begin()
  161. yield
  162. self.impl.emit_commit()
  163. return begin_commit()
  164. else:
  165. return self.bind.begin()
  166. def get_current_revision(self):
  167. """Return the current revision, usually that which is present
  168. in the ``alembic_version`` table in the database.
  169. This method intends to be used only for a migration stream that
  170. does not contain unmerged branches in the target database;
  171. if there are multiple branches present, an exception is raised.
  172. The :meth:`.MigrationContext.get_current_heads` should be preferred
  173. over this method going forward in order to be compatible with
  174. branch migration support.
  175. If this :class:`.MigrationContext` was configured in "offline"
  176. mode, that is with ``as_sql=True``, the ``starting_rev``
  177. parameter is returned instead, if any.
  178. """
  179. heads = self.get_current_heads()
  180. if len(heads) == 0:
  181. return None
  182. elif len(heads) > 1:
  183. raise util.CommandError(
  184. "Version table '%s' has more than one head present; "
  185. "please use get_current_heads()" % self.version_table)
  186. else:
  187. return heads[0]
  188. def get_current_heads(self):
  189. """Return a tuple of the current 'head versions' that are represented
  190. in the target database.
  191. For a migration stream without branches, this will be a single
  192. value, synonymous with that of
  193. :meth:`.MigrationContext.get_current_revision`. However when multiple
  194. unmerged branches exist within the target database, the returned tuple
  195. will contain a value for each head.
  196. If this :class:`.MigrationContext` was configured in "offline"
  197. mode, that is with ``as_sql=True``, the ``starting_rev``
  198. parameter is returned in a one-length tuple.
  199. If no version table is present, or if there are no revisions
  200. present, an empty tuple is returned.
  201. .. versionadded:: 0.7.0
  202. """
  203. if self.as_sql:
  204. start_from_rev = self._start_from_rev
  205. if start_from_rev == 'base':
  206. start_from_rev = None
  207. elif start_from_rev is not None and self.script:
  208. start_from_rev = \
  209. self.script.get_revision(start_from_rev).revision
  210. return util.to_tuple(start_from_rev, default=())
  211. else:
  212. if self._start_from_rev:
  213. raise util.CommandError(
  214. "Can't specify current_rev to context "
  215. "when using a database connection")
  216. if not self._has_version_table():
  217. return ()
  218. return tuple(
  219. row[0] for row in self.connection.execute(self._version.select())
  220. )
  221. def _ensure_version_table(self):
  222. self._version.create(self.connection, checkfirst=True)
  223. def _has_version_table(self):
  224. return self.connection.dialect.has_table(
  225. self.connection, self.version_table, self.version_table_schema)
  226. def stamp(self, script_directory, revision):
  227. """Stamp the version table with a specific revision.
  228. This method calculates those branches to which the given revision
  229. can apply, and updates those branches as though they were migrated
  230. towards that revision (either up or down). If no current branches
  231. include the revision, it is added as a new branch head.
  232. .. versionadded:: 0.7.0
  233. """
  234. heads = self.get_current_heads()
  235. if not self.as_sql and not heads:
  236. self._ensure_version_table()
  237. head_maintainer = HeadMaintainer(self, heads)
  238. for step in script_directory._stamp_revs(revision, heads):
  239. head_maintainer.update_to_step(step)
  240. def run_migrations(self, **kw):
  241. """Run the migration scripts established for this
  242. :class:`.MigrationContext`, if any.
  243. The commands in :mod:`alembic.command` will set up a function
  244. that is ultimately passed to the :class:`.MigrationContext`
  245. as the ``fn`` argument. This function represents the "work"
  246. that will be done when :meth:`.MigrationContext.run_migrations`
  247. is called, typically from within the ``env.py`` script of the
  248. migration environment. The "work function" then provides an iterable
  249. of version callables and other version information which
  250. in the case of the ``upgrade`` or ``downgrade`` commands are the
  251. list of version scripts to invoke. Other commands yield nothing,
  252. in the case that a command wants to run some other operation
  253. against the database such as the ``current`` or ``stamp`` commands.
  254. :param \**kw: keyword arguments here will be passed to each
  255. migration callable, that is the ``upgrade()`` or ``downgrade()``
  256. method within revision scripts.
  257. """
  258. self.impl.start_migrations()
  259. heads = self.get_current_heads()
  260. if not self.as_sql and not heads:
  261. self._ensure_version_table()
  262. head_maintainer = HeadMaintainer(self, heads)
  263. starting_in_transaction = not self.as_sql and \
  264. self._in_connection_transaction()
  265. for step in self._migrations_fn(heads, self):
  266. with self.begin_transaction(_per_migration=True):
  267. if self.as_sql and not head_maintainer.heads:
  268. # for offline mode, include a CREATE TABLE from
  269. # the base
  270. self._version.create(self.connection)
  271. log.info("Running %s", step)
  272. if self.as_sql:
  273. self.impl.static_output("-- Running %s" % (step.short_log,))
  274. step.migration_fn(**kw)
  275. # previously, we wouldn't stamp per migration
  276. # if we were in a transaction, however given the more
  277. # complex model that involves any number of inserts
  278. # and row-targeted updates and deletes, it's simpler for now
  279. # just to run the operations on every version
  280. head_maintainer.update_to_step(step)
  281. if not starting_in_transaction and not self.as_sql and \
  282. not self.impl.transactional_ddl and \
  283. self._in_connection_transaction():
  284. raise util.CommandError(
  285. "Migration \"%s\" has left an uncommitted "
  286. "transaction opened; transactional_ddl is False so "
  287. "Alembic is not committing transactions"
  288. % step)
  289. if self.as_sql and not head_maintainer.heads:
  290. self._version.drop(self.connection)
  291. def _in_connection_transaction(self):
  292. try:
  293. meth = self.connection.in_transaction
  294. except AttributeError:
  295. return False
  296. else:
  297. return meth()
  298. def execute(self, sql, execution_options=None):
  299. """Execute a SQL construct or string statement.
  300. The underlying execution mechanics are used, that is
  301. if this is "offline mode" the SQL is written to the
  302. output buffer, otherwise the SQL is emitted on
  303. the current SQLAlchemy connection.
  304. """
  305. self.impl._exec(sql, execution_options)
  306. def _stdout_connection(self, connection):
  307. def dump(construct, *multiparams, **params):
  308. self.impl._exec(construct)
  309. return MockEngineStrategy.MockConnection(self.dialect, dump)
  310. @property
  311. def bind(self):
  312. """Return the current "bind".
  313. In online mode, this is an instance of
  314. :class:`sqlalchemy.engine.Connection`, and is suitable
  315. for ad-hoc execution of any kind of usage described
  316. in :ref:`sqlexpression_toplevel` as well as
  317. for usage with the :meth:`sqlalchemy.schema.Table.create`
  318. and :meth:`sqlalchemy.schema.MetaData.create_all` methods
  319. of :class:`~sqlalchemy.schema.Table`,
  320. :class:`~sqlalchemy.schema.MetaData`.
  321. Note that when "standard output" mode is enabled,
  322. this bind will be a "mock" connection handler that cannot
  323. return results and is only appropriate for a very limited
  324. subset of commands.
  325. """
  326. return self.connection
  327. @property
  328. def config(self):
  329. """Return the :class:`.Config` used by the current environment, if any.
  330. .. versionadded:: 0.6.6
  331. """
  332. if self.environment_context:
  333. return self.environment_context.config
  334. else:
  335. return None
  336. def _compare_type(self, inspector_column, metadata_column):
  337. if self._user_compare_type is False:
  338. return False
  339. if callable(self._user_compare_type):
  340. user_value = self._user_compare_type(
  341. self,
  342. inspector_column,
  343. metadata_column,
  344. inspector_column.type,
  345. metadata_column.type
  346. )
  347. if user_value is not None:
  348. return user_value
  349. return self.impl.compare_type(
  350. inspector_column,
  351. metadata_column)
  352. def _compare_server_default(self, inspector_column,
  353. metadata_column,
  354. rendered_metadata_default,
  355. rendered_column_default):
  356. if self._user_compare_server_default is False:
  357. return False
  358. if callable(self._user_compare_server_default):
  359. user_value = self._user_compare_server_default(
  360. self,
  361. inspector_column,
  362. metadata_column,
  363. rendered_column_default,
  364. metadata_column.server_default,
  365. rendered_metadata_default
  366. )
  367. if user_value is not None:
  368. return user_value
  369. return self.impl.compare_server_default(
  370. inspector_column,
  371. metadata_column,
  372. rendered_metadata_default,
  373. rendered_column_default)
  374. class HeadMaintainer(object):
  375. def __init__(self, context, heads):
  376. self.context = context
  377. self.heads = set(heads)
  378. def _insert_version(self, version):
  379. assert version not in self.heads
  380. self.heads.add(version)
  381. self.context.impl._exec(
  382. self.context._version.insert().
  383. values(
  384. version_num=literal_column("'%s'" % version)
  385. )
  386. )
  387. def _delete_version(self, version):
  388. self.heads.remove(version)
  389. ret = self.context.impl._exec(
  390. self.context._version.delete().where(
  391. self.context._version.c.version_num ==
  392. literal_column("'%s'" % version)))
  393. if not self.context.as_sql and ret.rowcount != 1:
  394. raise util.CommandError(
  395. "Online migration expected to match one "
  396. "row when deleting '%s' in '%s'; "
  397. "%d found"
  398. % (version,
  399. self.context.version_table, ret.rowcount))
  400. def _update_version(self, from_, to_):
  401. assert to_ not in self.heads
  402. self.heads.remove(from_)
  403. self.heads.add(to_)
  404. ret = self.context.impl._exec(
  405. self.context._version.update().
  406. values(version_num=literal_column("'%s'" % to_)).where(
  407. self.context._version.c.version_num
  408. == literal_column("'%s'" % from_))
  409. )
  410. if not self.context.as_sql and ret.rowcount != 1:
  411. raise util.CommandError(
  412. "Online migration expected to match one "
  413. "row when updating '%s' to '%s' in '%s'; "
  414. "%d found"
  415. % (from_, to_, self.context.version_table, ret.rowcount))
  416. def update_to_step(self, step):
  417. if step.should_delete_branch(self.heads):
  418. vers = step.delete_version_num
  419. log.debug("branch delete %s", vers)
  420. self._delete_version(vers)
  421. elif step.should_create_branch(self.heads):
  422. vers = step.insert_version_num
  423. log.debug("new branch insert %s", vers)
  424. self._insert_version(vers)
  425. elif step.should_merge_branches(self.heads):
  426. # delete revs, update from rev, update to rev
  427. (delete_revs, update_from_rev,
  428. update_to_rev) = step.merge_branch_idents(self.heads)
  429. log.debug(
  430. "merge, delete %s, update %s to %s",
  431. delete_revs, update_from_rev, update_to_rev)
  432. for delrev in delete_revs:
  433. self._delete_version(delrev)
  434. self._update_version(update_from_rev, update_to_rev)
  435. elif step.should_unmerge_branches(self.heads):
  436. (update_from_rev, update_to_rev,
  437. insert_revs) = step.unmerge_branch_idents(self.heads)
  438. log.debug(
  439. "unmerge, insert %s, update %s to %s",
  440. insert_revs, update_from_rev, update_to_rev)
  441. for insrev in insert_revs:
  442. self._insert_version(insrev)
  443. self._update_version(update_from_rev, update_to_rev)
  444. else:
  445. from_, to_ = step.update_version_num(self.heads)
  446. log.debug("update %s to %s", from_, to_)
  447. self._update_version(from_, to_)
  448. class MigrationStep(object):
  449. @property
  450. def name(self):
  451. return self.migration_fn.__name__
  452. @classmethod
  453. def upgrade_from_script(cls, revision_map, script):
  454. return RevisionStep(revision_map, script, True)
  455. @classmethod
  456. def downgrade_from_script(cls, revision_map, script):
  457. return RevisionStep(revision_map, script, False)
  458. @property
  459. def is_downgrade(self):
  460. return not self.is_upgrade
  461. @property
  462. def short_log(self):
  463. return "%s %s -> %s" % (
  464. self.name,
  465. util.format_as_comma(self.from_revisions_no_deps),
  466. util.format_as_comma(self.to_revisions_no_deps)
  467. )
  468. def __str__(self):
  469. if self.doc:
  470. return "%s %s -> %s, %s" % (
  471. self.name,
  472. util.format_as_comma(self.from_revisions_no_deps),
  473. util.format_as_comma(self.to_revisions_no_deps),
  474. self.doc
  475. )
  476. else:
  477. return self.short_log
  478. class RevisionStep(MigrationStep):
  479. def __init__(self, revision_map, revision, is_upgrade):
  480. self.revision_map = revision_map
  481. self.revision = revision
  482. self.is_upgrade = is_upgrade
  483. if is_upgrade:
  484. self.migration_fn = revision.module.upgrade
  485. else:
  486. self.migration_fn = revision.module.downgrade
  487. def __repr__(self):
  488. return "RevisionStep(%r, is_upgrade=%r)" % (
  489. self.revision.revision, self.is_upgrade
  490. )
  491. def __eq__(self, other):
  492. return isinstance(other, RevisionStep) and \
  493. other.revision == self.revision and \
  494. self.is_upgrade == other.is_upgrade
  495. @property
  496. def doc(self):
  497. return self.revision.doc
  498. @property
  499. def from_revisions(self):
  500. if self.is_upgrade:
  501. return self.revision._all_down_revisions
  502. else:
  503. return (self.revision.revision, )
  504. @property
  505. def from_revisions_no_deps(self):
  506. if self.is_upgrade:
  507. return self.revision._versioned_down_revisions
  508. else:
  509. return (self.revision.revision, )
  510. @property
  511. def to_revisions(self):
  512. if self.is_upgrade:
  513. return (self.revision.revision, )
  514. else:
  515. return self.revision._all_down_revisions
  516. @property
  517. def to_revisions_no_deps(self):
  518. if self.is_upgrade:
  519. return (self.revision.revision, )
  520. else:
  521. return self.revision._versioned_down_revisions
  522. @property
  523. def _has_scalar_down_revision(self):
  524. return len(self.revision._all_down_revisions) == 1
  525. def should_delete_branch(self, heads):
  526. """A delete is when we are a. in a downgrade and b.
  527. we are going to the "base" or we are going to a version that
  528. is implied as a dependency on another version that is remaining.
  529. """
  530. if not self.is_downgrade:
  531. return False
  532. if self.revision.revision not in heads:
  533. return False
  534. downrevs = self.revision._all_down_revisions
  535. if not downrevs:
  536. # is a base
  537. return True
  538. else:
  539. # determine what the ultimate "to_revisions" for an
  540. # unmerge would be. If there are none, then we're a delete.
  541. to_revisions = self._unmerge_to_revisions(heads)
  542. return not to_revisions
  543. def merge_branch_idents(self, heads):
  544. other_heads = set(heads).difference(self.from_revisions)
  545. if other_heads:
  546. ancestors = set(
  547. r.revision for r in
  548. self.revision_map._get_ancestor_nodes(
  549. self.revision_map.get_revisions(other_heads),
  550. check=False
  551. )
  552. )
  553. from_revisions = list(
  554. set(self.from_revisions).difference(ancestors))
  555. else:
  556. from_revisions = list(self.from_revisions)
  557. return (
  558. # delete revs, update from rev, update to rev
  559. list(from_revisions[0:-1]), from_revisions[-1],
  560. self.to_revisions[0]
  561. )
  562. def _unmerge_to_revisions(self, heads):
  563. other_heads = set(heads).difference([self.revision.revision])
  564. if other_heads:
  565. ancestors = set(
  566. r.revision for r in
  567. self.revision_map._get_ancestor_nodes(
  568. self.revision_map.get_revisions(other_heads),
  569. check=False
  570. )
  571. )
  572. return list(set(self.to_revisions).difference(ancestors))
  573. else:
  574. return self.to_revisions
  575. def unmerge_branch_idents(self, heads):
  576. to_revisions = self._unmerge_to_revisions(heads)
  577. return (
  578. # update from rev, update to rev, insert revs
  579. self.from_revisions[0], to_revisions[-1],
  580. to_revisions[0:-1]
  581. )
  582. def should_create_branch(self, heads):
  583. if not self.is_upgrade:
  584. return False
  585. downrevs = self.revision._all_down_revisions
  586. if not downrevs:
  587. # is a base
  588. return True
  589. else:
  590. # none of our downrevs are present, so...
  591. # we have to insert our version. This is true whether
  592. # or not there is only one downrev, or multiple (in the latter
  593. # case, we're a merge point.)
  594. if not heads.intersection(downrevs):
  595. return True
  596. else:
  597. return False
  598. def should_merge_branches(self, heads):
  599. if not self.is_upgrade:
  600. return False
  601. downrevs = self.revision._all_down_revisions
  602. if len(downrevs) > 1 and \
  603. len(heads.intersection(downrevs)) > 1:
  604. return True
  605. return False
  606. def should_unmerge_branches(self, heads):
  607. if not self.is_downgrade:
  608. return False
  609. downrevs = self.revision._all_down_revisions
  610. if self.revision.revision in heads and len(downrevs) > 1:
  611. return True
  612. return False
  613. def update_version_num(self, heads):
  614. if not self._has_scalar_down_revision:
  615. downrev = heads.intersection(self.revision._all_down_revisions)
  616. assert len(downrev) == 1, \
  617. "Can't do an UPDATE because downrevision is ambiguous"
  618. down_revision = list(downrev)[0]
  619. else:
  620. down_revision = self.revision._all_down_revisions[0]
  621. if self.is_upgrade:
  622. return down_revision, self.revision.revision
  623. else:
  624. return self.revision.revision, down_revision
  625. @property
  626. def delete_version_num(self):
  627. return self.revision.revision
  628. @property
  629. def insert_version_num(self):
  630. return self.revision.revision
  631. class StampStep(MigrationStep):
  632. def __init__(self, from_, to_, is_upgrade, branch_move):
  633. self.from_ = util.to_tuple(from_, default=())
  634. self.to_ = util.to_tuple(to_, default=())
  635. self.is_upgrade = is_upgrade
  636. self.branch_move = branch_move
  637. self.migration_fn = self.stamp_revision
  638. doc = None
  639. def stamp_revision(self, **kw):
  640. return None
  641. def __eq__(self, other):
  642. return isinstance(other, StampStep) and \
  643. other.from_revisions == self.revisions and \
  644. other.to_revisions == self.to_revisions and \
  645. other.branch_move == self.branch_move and \
  646. self.is_upgrade == other.is_upgrade
  647. @property
  648. def from_revisions(self):
  649. return self.from_
  650. @property
  651. def to_revisions(self):
  652. return self.to_
  653. @property
  654. def from_revisions_no_deps(self):
  655. return self.from_
  656. @property
  657. def to_revisions_no_deps(self):
  658. return self.to_
  659. @property
  660. def delete_version_num(self):
  661. assert len(self.from_) == 1
  662. return self.from_[0]
  663. @property
  664. def insert_version_num(self):
  665. assert len(self.to_) == 1
  666. return self.to_[0]
  667. def update_version_num(self, heads):
  668. assert len(self.from_) == 1
  669. assert len(self.to_) == 1
  670. return self.from_[0], self.to_[0]
  671. def merge_branch_idents(self, heads):
  672. return (
  673. # delete revs, update from rev, update to rev
  674. list(self.from_[0:-1]), self.from_[-1],
  675. self.to_[0]
  676. )
  677. def unmerge_branch_idents(self, heads):
  678. return (
  679. # update from rev, update to rev, insert revs
  680. self.from_[0], self.to_[-1],
  681. list(self.to_[0:-1])
  682. )
  683. def should_delete_branch(self, heads):
  684. return self.is_downgrade and self.branch_move
  685. def should_create_branch(self, heads):
  686. return self.is_upgrade and self.branch_move
  687. def should_merge_branches(self, heads):
  688. return len(self.from_) > 1
  689. def should_unmerge_branches(self, heads):
  690. return len(self.to_) > 1